import json import logging import os from collections import OrderedDict from decimal import Decimal from pathlib import Path from typing import Callable, Union from typing import Tuple, Optional, List, Dict import meeteval import numpy as np import pandas as pd from meeteval.io.seglst import SegLstSegment from meeteval.wer.wer.orc import OrcErrorRate # this must be called before any other loggers are instantiated to take effect logging.basicConfig(level=logging.INFO, format='%(asctime)s [%(levelname)s] [%(name)s] %(message)s') def get_logger(name: str): """ All modules should use this function to get a logger. This way, we ensure all loggers are instantiated after basicConfig() call and inherit the same config. """ return logging.getLogger(name) _LOG = get_logger('wer') def create_dummy_seg_list(session_id): return meeteval.io.SegLST( [{'session_id': session_id, 'start_time': Decimal(0), 'end_time': Decimal(0), 'speaker': '', 'words': ''}]) def calc_session_tcp_wer(ref, hyp, collar): res = meeteval.wer.tcpwer(reference=ref, hypothesis=hyp, collar=collar) res_df = pd.DataFrame.from_dict(res, orient='index').reset_index(names='session_id') keys = ['error_rate', 'errors', 'length', 'insertions', 'deletions', 'substitutions', 'missed_speaker', 'falarm_speaker', 'scored_speaker', 'assignment'] return (res_df[['session_id'] + keys] .rename(columns={k: 'tcp_' + k for k in keys}) .rename(columns={'tcp_error_rate': 'tcp_wer'})) def calc_wer( ref_seglst: SegLstSegment, tcp_hyp_seglst: SegLstSegment, collar: int = 5, metrics_list: List[str] = None) -> pd.DataFrame: """ Calculates tcpWER and tcorcWER for each session in hypothesis files using meeteval, and saves the error information to .json. Text normalization is applied to both hypothesis and reference. Args: out_dir: the directory to save the ref.json reference transcript to (extracted from gt_utt_df). tcp_wer_hyp_json: path to hypothesis .json file for tcpWER, or json structure. tcorc_wer_hyp_json: path to hypothesis .json file for tcorcWER, or json structure. gt_utt_df: dataframe of ground truth utterances. must include the sessions in the hypothesis files. see load_data() function. tn: text normalizer collar: tolerance of tcpWER to temporal misalignment between hypothesis and reference. save_visualizations: if True, save html visualizations of alignment between hyp and ref. meeting_id_is_session_id: if True, the session_id in the hypothesis/ref files is the same as the meeting_id. Returns: wer_df: pd.DataFrame with columns - 'session_id' - same as in hypothesis files 'tcp_wer': tcpWER 'tcorc_wer': tcorcWER ... intermediate tcpWER/tcorcWER fields such as insertions/deletions. see in code. """ # json to SegLST structure (Segment-wise Long-form Speech Transcription annotation) if len(tcp_hyp_seglst) == 0: tcp_hyp_seglst = create_dummy_seg_list(ref_seglst.segments[0]['session_id']) _LOG.warning(f"Empty tcp_wer_hyp_json, using dummy segment: {tcp_hyp_seglst.segments[0]}") wers_to_concat = [] if "tcp_wer" in metrics_list: tcp_wer_res = calc_session_tcp_wer(ref_seglst, tcp_hyp_seglst, collar) wers_to_concat.append(tcp_wer_res.drop(columns='session_id')) wer_df = pd.concat(wers_to_concat, axis=1) wer_df['session_id'] = ref_seglst.segments[0]['session_id'] _LOG.debug('Done calculating WER') _LOG.debug(f"\n{wer_df[['session_id', *metrics_list]]}") return wer_df def aggregate_wer_metrics(wer_df: pd.DataFrame, metrics_list: List[str]) -> Dict: num_wer_df = wer_df._get_numeric_data() metrics = num_wer_df.sum().to_dict(into=OrderedDict) for metric in metrics_list: mprefix, _ = metric.split("_", maxsplit=1) metrics[mprefix + "_wer"] = metrics[mprefix + "_errors"] / metrics[mprefix + "_length"] for k in ['missed_speaker', 'falarm_speaker', 'scored_speaker']: # compute mean for this keys key = f"{mprefix}_{k}" new_key = f"{mprefix}_mean_{k}" if key not in metrics: continue metrics[new_key] = metrics[key] / len(num_wer_df) del metrics[key] return metrics def normalize_segment(segment: SegLstSegment, tn): words = segment["words"] words = tn(words) segment["words"] = words return segment def assign_streams(tcorc_hyp_seglst): tcorc_hyp_seglst = tcorc_hyp_seglst.groupby(key='speaker') per_stream_list = [[] for _ in range(len(tcorc_hyp_seglst))] for speaker_id, speaker_seglst in tcorc_hyp_seglst.items(): speaker_seglst = speaker_seglst.sorted(key='start_time') for seg in speaker_seglst: # check if current segment does not overlap with any of the segments in per_stream_list for i in range(len(per_stream_list)): if not any(seg['start_time'] < s['end_time'] and seg['end_time'] > s['start_time'] for s in per_stream_list[i]): seg['speaker'] = i per_stream_list[i].append(seg) break else: raise ValueError('No stream found for segment') tcorc_hyp_seglst = meeteval.io.SegLST([seg for stream in per_stream_list for seg in stream]).sorted( key='start_time') return tcorc_hyp_seglst def filter_empty_segments(seg_lst): return seg_lst.filter(lambda seg: seg['words'] != '') def find_first_non_overlapping_segment_streams(per_speaker_groups, per_speaker_vad_masks): for speaker_id, speaker_seglst in per_speaker_groups.items(): for other_speaker_id, other_speaker_seglst in per_speaker_groups.items(): if speaker_id != other_speaker_id: vad_mask_merged = per_speaker_vad_masks[speaker_id] & per_speaker_vad_masks[other_speaker_id] if not vad_mask_merged.any(): return (speaker_id, other_speaker_id) def change_speaker_id(segment, speaker_id): segment['speaker'] = speaker_id return segment def merge_streams(tcorc_hyp_seglst): per_speaker_groups = tcorc_hyp_seglst.groupby(key='speaker') # create per speaker vad masks per_speaker_vad_masks = {} for speaker_id, speaker_seglst in per_speaker_groups.items(): per_speaker_vad_masks[speaker_id] = create_vad_mask(speaker_seglst, time_step=0.01) longest_mask = max(len(mask) for mask in per_speaker_vad_masks.values()) # pad all masks to the same length for speaker_id, mask in per_speaker_vad_masks.items(): per_speaker_vad_masks[speaker_id] = np.pad(mask, (0, longest_mask - len(mask))) # recursively merge all pairs of speakers that have no overlapping vad masks while True: res = find_first_non_overlapping_segment_streams(per_speaker_groups, per_speaker_vad_masks) if res is None: break speaker_id, other_speaker_id = res per_speaker_groups[speaker_id] = per_speaker_groups[speaker_id] + per_speaker_groups[other_speaker_id].map( lambda seg: change_speaker_id(seg, speaker_id)) per_speaker_vad_masks[speaker_id] = per_speaker_vad_masks[speaker_id] | per_speaker_vad_masks[other_speaker_id] del per_speaker_groups[other_speaker_id] del per_speaker_vad_masks[other_speaker_id] tcorc_hyp_seglst = meeteval.io.SegLST( [seg for speaker_seglst in per_speaker_groups.values() for seg in speaker_seglst]).sorted(key='start_time') return tcorc_hyp_seglst def normalize_segment(segment: SegLstSegment, tn): words = segment["words"] words = tn(words) segment["words"] = words return segment def create_vad_mask(segments, time_step=0.1, total_duration=None): """ Create a VAD mask for the given segments. :param segments: List of segments, each containing 'start_time' and 'end_time' :param time_step: The resolution of the VAD mask in seconds (default: 100ms) :param total_duration: Optionally specify the total duration to create the mask. If not provided, the mask will be generated based on the maximum end time of the segments. :return: VAD mask as a numpy array, where 1 represents voice activity and 0 represents silence. """ # Find the total duration if not provided if total_duration is None: total_duration = max(seg['end_time'] for seg in segments) # Initialize VAD mask as zeros (silence) mask_length = int(float(total_duration) / time_step) + 1 vad_mask = np.zeros(mask_length, dtype=bool) # Iterate over segments and mark the corresponding times as active (1) for seg in segments: start_idx = int(float(seg['start_time']) / time_step) end_idx = int(float(seg['end_time']) / time_step) vad_mask[start_idx:end_idx] = 1 return vad_mask def find_group_splits(vad, group_duration=30, time_step=0.1): non_active_indices = np.argwhere(~vad).squeeze(axis=-1) splits = [] group_shift = group_duration / time_step next_offset = group_shift for i in non_active_indices: if i >= next_offset: splits.append(i) next_offset = i + group_shift return splits def map_utterance_to_split(utterance_start_time, splits): for i, split in enumerate(splits): if utterance_start_time < split: return i return len(splits) def agregate_errors_across_groups(res, session_id): overall_error_number = sum([group.errors for group in res.values()]) overall_length = sum([group.length for group in res.values()]) overall_errors = { 'error_rate': overall_error_number / overall_length, 'errors': overall_error_number, 'length': overall_length, 'insertions': sum([group.insertions for group in res.values()]), 'deletions': sum([group.deletions for group in res.values()]), 'substitutions': sum([group.substitutions for group in res.values()]), 'assignment': [] } for group in res.values(): overall_errors['assignment'].extend(list(group.assignment)) overall_errors['assignment'] = tuple(overall_errors['assignment']) res = {session_id: OrcErrorRate(errors=overall_errors["errors"], length=overall_errors["length"], insertions=overall_errors["insertions"], deletions=overall_errors["deletions"], substitutions=overall_errors["substitutions"], hypothesis_self_overlap=None, reference_self_overlap=None, assignment=overall_errors["assignment"])} return res def aggregate_wer_metrics(wer_df: pd.DataFrame, metrics_list: List[str]) -> Dict: num_wer_df = wer_df._get_numeric_data() metrics = num_wer_df.sum().to_dict(into=OrderedDict) for metric in metrics_list: mprefix, _ = metric.split("_", maxsplit=1) metrics[mprefix + "_wer"] = metrics[mprefix + "_errors"] / metrics[mprefix + "_length"] for k in ['missed_speaker', 'falarm_speaker', 'scored_speaker']: # compute mean for this keys key = f"{mprefix}_{k}" new_key = f"{mprefix}_mean_{k}" if key not in metrics: continue metrics[new_key] = metrics[key] / len(num_wer_df) del metrics[key] return metrics