# EMMA_leaderboard / utils.py
# Committed by Lakoc in 605b3ec ("Initial commit"); original size 11.8 kB.
import json
import logging
import os
from collections import OrderedDict
from decimal import Decimal
from pathlib import Path
from typing import Callable, Union
from typing import Tuple, Optional, List, Dict
import meeteval
import numpy as np
import pandas as pd
from meeteval.io.seglst import SegLstSegment
from meeteval.wer.wer.orc import OrcErrorRate
# Configure the root logger at import time: logging.basicConfig() does nothing once the root
# logger already has handlers, so it has to run before other code attaches one.
logging.basicConfig(
    format='%(asctime)s [%(levelname)s] [%(name)s] %(message)s',
    level=logging.INFO,
)


def get_logger(name: str):
    """
    Return the logger registered under ``name``.

    Every module should obtain its logger through this helper so that all loggers are created
    after the basicConfig() call above and share its configuration.
    """
    logger = logging.getLogger(name)
    return logger


_LOG = get_logger(name='wer')
def create_dummy_seg_list(session_id):
    """
    Build a SegLST holding a single placeholder segment for ``session_id``.

    The segment starts and ends at time 0 and has an empty speaker and empty words.
    """
    placeholder = {
        'session_id': session_id,
        'start_time': Decimal(0),
        'end_time': Decimal(0),
        'speaker': '',
        'words': '',
    }
    return meeteval.io.SegLST([placeholder])
def calc_session_tcp_wer(ref, hyp, collar):
    """
    Score ``hyp`` against ``ref`` with meeteval's tcpWER.

    Returns:
        pd.DataFrame with one row per session returned by meeteval: a 'session_id' column
        followed by the tcpWER fields, each prefixed with ``tcp_`` (the error rate itself is
        named 'tcp_wer').
    """
    fields = ['error_rate', 'errors', 'length', 'insertions', 'deletions', 'substitutions',
              'missed_speaker', 'falarm_speaker', 'scored_speaker', 'assignment']
    per_session = meeteval.wer.tcpwer(reference=ref, hypothesis=hyp, collar=collar)
    table = pd.DataFrame.from_dict(per_session, orient='index').reset_index(names='session_id')

    # Single rename map: every field gets the tcp_ prefix, except error_rate which becomes tcp_wer.
    new_names = {field: 'tcp_' + field for field in fields}
    new_names['error_rate'] = 'tcp_wer'

    selected = table[['session_id', *fields]]
    return selected.rename(columns=new_names)
def calc_wer(
        ref_seglst: SegLstSegment,
        tcp_hyp_seglst: SegLstSegment,
        collar: int = 5,
        metrics_list: Optional[List[str]] = None) -> pd.DataFrame:
    """
    Calculate per-session WER metrics of a hypothesis against a reference using meeteval.

    Currently only tcpWER is supported (request it with ``"tcp_wer"`` in ``metrics_list``).
    The inputs are scored as given: no text normalization is applied and nothing is written
    to disk.

    Args:
        ref_seglst: reference transcript as a meeteval SegLST. Must contain at least one segment
            when the hypothesis is empty, because its first segment's session_id is then used
            for the placeholder hypothesis.
        tcp_hyp_seglst: hypothesis transcript as a meeteval SegLST. If it is empty, a single
            placeholder segment with empty words (create_dummy_seg_list) is scored instead.
        collar: tolerance of tcpWER to temporal misalignment between hypothesis and reference.
        metrics_list: metrics to compute. Defaults to ``["tcp_wer"]``. Unsupported names are
            skipped with a warning.

    Returns:
        wer_df: pd.DataFrame with one row per session scored by meeteval. It holds the ``tcp_*``
        columns produced by calc_session_tcp_wer ('tcp_wer', 'tcp_errors', 'tcp_length', ...),
        followed by 'session_id'.

    Raises:
        ValueError: if ``metrics_list`` does not contain any supported metric.
    """
    supported_metrics = ('tcp_wer',)
    if metrics_list is None:
        metrics_list = ['tcp_wer']
    unsupported = [metric for metric in metrics_list if metric not in supported_metrics]
    if unsupported:
        _LOG.warning(f"Ignoring unsupported metrics: {unsupported}")
    if len(unsupported) == len(metrics_list):
        raise ValueError(
            f"No supported metric requested (got {list(metrics_list)}, "
            f"supported: {list(supported_metrics)})")

    if len(tcp_hyp_seglst) == 0:
        # Score an empty hypothesis as one empty segment in the reference's first session.
        tcp_hyp_seglst = create_dummy_seg_list(ref_seglst.segments[0]['session_id'])
        _LOG.warning(f"Empty tcp_wer_hyp_json, using dummy segment: {tcp_hyp_seglst.segments[0]}")

    wers_to_concat = []
    session_ids = []
    if "tcp_wer" in metrics_list:
        tcp_wer_res = calc_session_tcp_wer(ref_seglst, tcp_hyp_seglst, collar)
        # Keep meeteval's per-row session ids. Previously every row was stamped with the first
        # reference session id, which mislabeled rows when several sessions were scored.
        session_ids = tcp_wer_res['session_id'].tolist()
        wers_to_concat.append(tcp_wer_res.drop(columns='session_id'))

    wer_df = pd.concat(wers_to_concat, axis=1)
    wer_df['session_id'] = session_ids
    _LOG.debug('Done calculating WER')
    if _LOG.isEnabledFor(logging.DEBUG):
        # Only show columns that were actually computed. Indexing an absent metric raised a
        # KeyError here even when debug logging was disabled.
        shown = [col for col in ['session_id', *metrics_list] if col in wer_df.columns]
        _LOG.debug(f"\n{wer_df[shown]}")
    return wer_df
def aggregate_wer_metrics(wer_df: pd.DataFrame, metrics_list: List[str]) -> Dict:
    """
    Aggregate per-session WER rows into corpus-level metrics.

    All numeric and boolean columns of ``wer_df`` are summed. Then, for every metric in
    ``metrics_list``, the text before the first underscore is used as a prefix (e.g. ``"tcp"``
    for ``"tcp_wer"``), and two adjustments are made:

    * ``<prefix>_wer`` is set to ``<prefix>_errors / <prefix>_length``, replacing the summed
      per-session rates. It is NaN when the total length is 0.
    * Each of ``<prefix>_missed_speaker``, ``<prefix>_falarm_speaker`` and
      ``<prefix>_scored_speaker`` that is present is replaced by its per-row mean, stored as
      ``<prefix>_mean_<name>``. The mean is NaN when ``wer_df`` has no rows.

    NOTE(review): this module contains two definitions of aggregate_wer_metrics. The one that
    appears later in the file is the one bound at import time, so keep them in sync or delete one.

    Args:
        wer_df: one row per session, e.g. as produced by calc_wer.
        metrics_list: metric names such as ``["tcp_wer"]``.

    Returns:
        OrderedDict mapping column / metric name to its aggregated value.

    Raises:
        KeyError: if ``<prefix>_errors`` or ``<prefix>_length`` is missing for a requested metric.
    """
    # Public equivalent of the private DataFrame._get_numeric_data(): numeric + bool columns.
    numeric_df = wer_df.select_dtypes(include=['number', 'bool'])
    metrics = numeric_df.sum().to_dict(into=OrderedDict)
    num_rows = len(numeric_df)
    for metric in metrics_list:
        mprefix = metric.split("_", maxsplit=1)[0]
        errors = metrics[mprefix + "_errors"]
        length = metrics[mprefix + "_length"]
        metrics[mprefix + "_wer"] = errors / length if length else float('nan')
        for k in ('missed_speaker', 'falarm_speaker', 'scored_speaker'):
            key = f"{mprefix}_{k}"
            if key not in metrics:
                continue
            # Replace the summed count by its mean over rows (appended at the end of the dict).
            total = metrics.pop(key)
            metrics[f"{mprefix}_mean_{k}"] = total / num_rows if num_rows else float('nan')
    return metrics
def normalize_segment(segment: SegLstSegment, tn):
    """
    Replace ``segment["words"]`` with ``tn(segment["words"])``.

    The segment is modified in place and the same object is returned.

    NOTE(review): normalize_segment is defined again later in this module, and that later
    definition replaces this one at import time. Consider removing one of the two.
    """
    segment["words"] = tn(segment["words"])
    return segment
def assign_streams(tcorc_hyp_seglst):
    """
    Re-pack hypothesis segments into streams so that no stream holds overlapping segments.

    Segments are grouped by their 'speaker' field, and one stream is created per group.
    Speakers are visited in grouping order, and each speaker's segments in 'start_time' order.
    Every segment is placed on the lowest-numbered stream with no segment overlapping it;
    segments that only touch at an endpoint do not count as overlapping. The segment's
    'speaker' field is overwritten with that stream's integer index.

    Raises:
        ValueError: if a segment overlaps some segment on every stream.

    Returns:
        A SegLST with all reassigned segments, sorted by 'start_time'.
    """
    speaker_groups = tcorc_hyp_seglst.groupby(key='speaker')
    streams = [[] for _ in range(len(speaker_groups))]

    def overlaps(candidate, placed):
        return candidate['start_time'] < placed['end_time'] and candidate['end_time'] > placed['start_time']

    for _, speaker_segments in speaker_groups.items():
        for seg in speaker_segments.sorted(key='start_time'):
            free_stream = next(
                (idx for idx, stream in enumerate(streams)
                 if not any(overlaps(seg, placed) for placed in stream)),
                None,
            )
            if free_stream is None:
                raise ValueError('No stream found for segment')
            seg['speaker'] = free_stream
            streams[free_stream].append(seg)

    flattened = [seg for stream in streams for seg in stream]
    return meeteval.io.SegLST(flattened).sorted(key='start_time')
def filter_empty_segments(seg_lst):
    """
    Return ``seg_lst.filter(...)`` keeping only segments whose 'words' is not the empty string.

    Segments whose words consist only of whitespace are kept.
    """
    def has_words(segment):
        return segment['words'] != ''

    return seg_lst.filter(has_words)
def find_first_non_overlapping_segment_streams(per_speaker_groups, per_speaker_vad_masks):
    """
    Find the first pair of distinct speakers whose VAD masks are never active at the same frame.

    Only the keys of ``per_speaker_groups`` are used; they define the candidate speakers and the
    search order. "No overlap" is symmetric, so the first hit in row-major order always lists the
    earlier speaker first. That is why only pairs (a, b) with a before b are examined.

    Returns:
        ``(speaker_id, other_speaker_id)`` for the first such pair, or None if every pair
        overlaps or there are fewer than two speakers.
    """
    speaker_ids = list(per_speaker_groups)
    for pos, speaker_id in enumerate(speaker_ids):
        for other_speaker_id in speaker_ids[pos + 1:]:
            shared_frames = per_speaker_vad_masks[speaker_id] & per_speaker_vad_masks[other_speaker_id]
            if not shared_frames.any():
                return speaker_id, other_speaker_id
    return None
def change_speaker_id(segment, speaker_id):
    """
    Overwrite the 'speaker' field of ``segment`` with ``speaker_id``.

    The segment is modified in place and the same object is returned, so the function can be
    used as a per-segment mapping callback.
    """
    segment['speaker'] = speaker_id
    return segment
def merge_streams(tcorc_hyp_seglst, time_step=0.01):
    """
    Merge speaker streams whose voice activity never overlaps.

    The segments are grouped by 'speaker', and each group is rasterized into a boolean VAD mask
    (create_vad_mask at ``time_step`` resolution). While some pair of remaining groups shares no
    active frame, the merge loop does the following:

    * Relabels the second group's segments with the first group's speaker id
      (change_speaker_id) and appends them to the first group.
    * Replaces the first group's mask with the OR of both masks.
    * Drops the second group.

    Args:
        tcorc_hyp_seglst: hypothesis SegLST whose segments carry a 'speaker' field.
        time_step: VAD mask resolution in seconds (defaults to 10 ms).

    Returns:
        A SegLST with the segments of all remaining groups, sorted by 'start_time'.
        An empty input yields an empty SegLST.
    """
    per_speaker_groups = tcorc_hyp_seglst.groupby(key='speaker')

    # One VAD mask per speaker.
    per_speaker_vad_masks = {
        speaker_id: create_vad_mask(speaker_seglst, time_step=time_step)
        for speaker_id, speaker_seglst in per_speaker_groups.items()
    }
    # Pad all masks to the same length so they can be combined element-wise.
    # default=0 keeps an input without speakers from failing on max() of an empty sequence.
    longest_mask = max((len(mask) for mask in per_speaker_vad_masks.values()), default=0)
    for speaker_id, mask in per_speaker_vad_masks.items():
        per_speaker_vad_masks[speaker_id] = np.pad(mask, (0, longest_mask - len(mask)))

    # Repeatedly merge the first pair of speakers without overlapping VAD frames until none is left.
    while True:
        res = find_first_non_overlapping_segment_streams(per_speaker_groups, per_speaker_vad_masks)
        if res is None:
            break
        speaker_id, other_speaker_id = res
        relabeled = per_speaker_groups[other_speaker_id].map(
            lambda seg: change_speaker_id(seg, speaker_id))
        per_speaker_groups[speaker_id] = per_speaker_groups[speaker_id] + relabeled
        per_speaker_vad_masks[speaker_id] = (
            per_speaker_vad_masks[speaker_id] | per_speaker_vad_masks[other_speaker_id])
        del per_speaker_groups[other_speaker_id]
        del per_speaker_vad_masks[other_speaker_id]

    return meeteval.io.SegLST(
        [seg for speaker_seglst in per_speaker_groups.values() for seg in speaker_seglst]
    ).sorted(key='start_time')
def normalize_segment(segment: SegLstSegment, tn):
    """
    Apply the text normalizer ``tn`` to ``segment["words"]``, storing the result back in place.

    Returns the same (modified) segment object.

    NOTE(review): this redefines an earlier normalize_segment in the same module. Being later in
    the file, this is the definition bound at import time.
    """
    original_words = segment["words"]
    segment["words"] = tn(original_words)
    return segment
def create_vad_mask(segments, time_step=0.1, total_duration=None):
    """
    Create a VAD mask for the given segments.

    :param segments: iterable of segments, each a mapping with 'start_time' and 'end_time' in
        seconds (any value float() accepts, e.g. Decimal).
    :param time_step: The resolution of the VAD mask in seconds (default: 100ms)
    :param total_duration: Optionally specify the total duration to create the mask.
        If not provided, the maximum end time of the segments is used (0 when there are none).
    :return: boolean numpy array with ``floor(total_duration / time_step) + 1`` entries.
        An entry is True for frames inside a segment's [start_time, end_time) span and False
        for silence.
    """
    # Tolerance, in frames, for converting times to frame indices. A time exactly on a frame
    # boundary can land just below it in binary floating point (0.3 / 0.1 == 2.9999999999999996),
    # and plain truncation would then drop the segment's last frame.
    eps = 1e-6

    def to_frame(t):
        # Clamp at 0: a negative index would make the slice count from the end of the mask.
        return max(int(np.floor(float(t) / time_step + eps)), 0)

    # Materialize once; the segments are iterated twice when total_duration is not given.
    segments = list(segments)
    if total_duration is None:
        total_duration = max((seg['end_time'] for seg in segments), default=0)

    vad_mask = np.zeros(to_frame(total_duration) + 1, dtype=bool)
    for seg in segments:
        vad_mask[to_frame(seg['start_time']):to_frame(seg['end_time'])] = True
    return vad_mask
def find_group_splits(vad, group_duration=30, time_step=0.1):
    """
    Choose split points that cut a VAD mask into chunks of roughly ``group_duration`` seconds.

    Splits are placed only on inactive frames. The first split is the first inactive frame at
    or after ``group_duration / time_step`` frames. Each later split is the first inactive frame
    at least that many frames after the previous split.

    Args:
        vad: 1-D voice-activity mask in which truthy entries are active frames. Boolean and
            0/1 numeric masks are both accepted.
        group_duration: target chunk length in seconds.
        time_step: duration of one mask frame in seconds.

    Returns:
        List of ascending frame indices at which to split.
    """
    # Convert to bool before negating: ~ on a 0/1 integer array is a bitwise NOT (0 -> -1,
    # 1 -> -2), which would make every frame look inactive.
    inactive_frames = np.flatnonzero(~np.asarray(vad, dtype=bool))
    splits = []
    group_shift = group_duration / time_step
    next_offset = group_shift
    for i in inactive_frames:
        if i >= next_offset:
            splits.append(i)
            next_offset = i + group_shift
    return splits
def map_utterance_to_split(utterance_start_time, splits):
    """
    Return the index of the group an utterance starting at ``utterance_start_time`` falls into.

    The result is the position of the first split point strictly greater than the start time,
    or ``len(splits)`` if there is none. ``splits`` is presumably ascending and in the same units
    as the start time (TODO confirm against callers).
    """
    first_later = next(
        (idx for idx, boundary in enumerate(splits) if utterance_start_time < boundary),
        None,
    )
    return len(splits) if first_later is None else first_later
def agregate_errors_across_groups(res, session_id):
    """
    Combine the ORC error results of all groups of one session into a single OrcErrorRate.

    Args:
        res: mapping from group key to a per-group result exposing ``errors``, ``length``,
            ``insertions``, ``deletions``, ``substitutions`` and ``assignment``. These are
            presumably OrcErrorRate objects (TODO confirm against callers).
        session_id: key under which the combined result is returned.

    Returns:
        ``{session_id: OrcErrorRate}``. Its counts are the sums over all groups, and its
        assignment is the concatenation of the groups' assignments in ``res`` iteration order.
        Self-overlap information is not carried over (both fields are None).
    """
    groups = list(res.values())
    # No overall error-rate ratio is computed here. It was never passed to OrcErrorRate, and
    # the old errors / length division only raised ZeroDivisionError for empty input or zero
    # reference length.
    combined_assignment = tuple(item for group in groups for item in group.assignment)
    return {session_id: OrcErrorRate(errors=sum(group.errors for group in groups),
                                     length=sum(group.length for group in groups),
                                     insertions=sum(group.insertions for group in groups),
                                     deletions=sum(group.deletions for group in groups),
                                     substitutions=sum(group.substitutions for group in groups),
                                     hypothesis_self_overlap=None,
                                     reference_self_overlap=None,
                                     assignment=combined_assignment)}
def aggregate_wer_metrics(wer_df: pd.DataFrame, metrics_list: List[str]) -> Dict:
    """
    Aggregate per-session WER rows into corpus-level metrics.

    All numeric and boolean columns of ``wer_df`` are summed. Then, for every metric in
    ``metrics_list``, the text before the first underscore is used as a prefix (e.g. ``"tcp"``
    for ``"tcp_wer"``), and two adjustments are made:

    * ``<prefix>_wer`` is set to ``<prefix>_errors / <prefix>_length``, replacing the summed
      per-session rates. It is NaN when the total length is 0.
    * Each of ``<prefix>_missed_speaker``, ``<prefix>_falarm_speaker`` and
      ``<prefix>_scored_speaker`` that is present is replaced by its per-row mean, stored as
      ``<prefix>_mean_<name>``. The mean is NaN when ``wer_df`` has no rows.

    NOTE(review): this module contains two definitions of aggregate_wer_metrics. The one that
    appears later in the file is the one bound at import time, so keep them in sync or delete one.

    Args:
        wer_df: one row per session, e.g. as produced by calc_wer.
        metrics_list: metric names such as ``["tcp_wer"]``.

    Returns:
        OrderedDict mapping column / metric name to its aggregated value.

    Raises:
        KeyError: if ``<prefix>_errors`` or ``<prefix>_length`` is missing for a requested metric.
    """
    # Public equivalent of the private DataFrame._get_numeric_data(): numeric + bool columns.
    numeric_df = wer_df.select_dtypes(include=['number', 'bool'])
    metrics = numeric_df.sum().to_dict(into=OrderedDict)
    num_rows = len(numeric_df)
    for metric in metrics_list:
        mprefix = metric.split("_", maxsplit=1)[0]
        errors = metrics[mprefix + "_errors"]
        length = metrics[mprefix + "_length"]
        metrics[mprefix + "_wer"] = errors / length if length else float('nan')
        for k in ('missed_speaker', 'falarm_speaker', 'scored_speaker'):
            key = f"{mprefix}_{k}"
            if key not in metrics:
                continue
            # Replace the summed count by its mean over rows (appended at the end of the dict).
            total = metrics.pop(key)
            metrics[f"{mprefix}_mean_{k}"] = total / num_rows if num_rows else float('nan')
    return metrics