Spaces:
Running
Running
File size: 11,762 Bytes
605b3ec |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 |
import json
import logging
import os
from collections import OrderedDict
from decimal import Decimal
from pathlib import Path
from typing import Callable, Union
from typing import Tuple, Optional, List, Dict
import meeteval
import numpy as np
import pandas as pd
from meeteval.io.seglst import SegLstSegment
from meeteval.wer.wer.orc import OrcErrorRate
# Configure the root logger at import time. logging.basicConfig() does nothing once the root
# logger already has handlers, so this must run before anything else sets up logging.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s [%(levelname)s] [%(name)s] %(message)s',
)


def get_logger(name: str):
    """
    Return the logger registered under ``name``.

    Modules should obtain their loggers through this helper so that every logger is created
    after the module-level basicConfig() call and shares the root configuration it sets up.
    """
    named_logger = logging.getLogger(name)
    return named_logger


# Logger used by the WER helpers in this module.
_LOG = get_logger('wer')
def create_dummy_seg_list(session_id):
    """
    Build a SegLST holding a single empty placeholder segment for ``session_id``.

    The placeholder has zero start/end time (as Decimal), an empty speaker and empty words.
    """
    placeholder = {
        'session_id': session_id,
        'start_time': Decimal(0),
        'end_time': Decimal(0),
        'speaker': '',
        'words': '',
    }
    return meeteval.io.SegLST([placeholder])
def calc_session_tcp_wer(ref, hyp, collar):
    """
    Compute tcpWER with meeteval and tabulate the result, one row per session.

    Args:
        ref: reference transcript, passed to meeteval.wer.tcpwer.
        hyp: hypothesis transcript, passed to meeteval.wer.tcpwer.
        collar: collar, passed to meeteval.wer.tcpwer.

    Returns:
        pd.DataFrame with 'session_id' followed by the tcpWER fields, each prefixed with
        'tcp_'; the error-rate column is named 'tcp_wer'.
    """
    per_session = meeteval.wer.tcpwer(reference=ref, hypothesis=hyp, collar=collar)
    table = pd.DataFrame.from_dict(per_session, orient='index')
    table = table.reset_index(names='session_id')
    fields = ['error_rate', 'errors', 'length', 'insertions', 'deletions', 'substitutions',
              'missed_speaker', 'falarm_speaker', 'scored_speaker', 'assignment']
    new_names = {field: f'tcp_{field}' for field in fields}
    # The error rate is exposed under the shorter name 'tcp_wer'.
    new_names['error_rate'] = 'tcp_wer'
    return table[['session_id', *fields]].rename(columns=new_names)
def calc_wer(
        ref_seglst: "meeteval.io.SegLST",
        tcp_hyp_seglst: "meeteval.io.SegLST",
        collar: int = 5,
        metrics_list: Optional[List[str]] = None) -> pd.DataFrame:
    """
    Calculate WER metrics for one session with meeteval.

    Only tcpWER ("tcp_wer") is currently supported. An empty hypothesis is replaced by a
    single empty dummy segment (see create_dummy_seg_list) so the session is still scored.

    Args:
        ref_seglst: reference transcript as a SegLST. The session_id of its first segment is
            used to label the output.
        tcp_hyp_seglst: hypothesis transcript as a SegLST, scored with tcpWER.
        collar: tolerance of tcpWER to temporal misalignment between hypothesis and
            reference; passed through to meeteval.wer.tcpwer.
        metrics_list: names of the metrics to compute. Defaults to ["tcp_wer"].

    Returns:
        wer_df: pd.DataFrame with the tcpWER columns produced by calc_session_tcp_wer
        ('tcp_wer', 'tcp_errors', 'tcp_length', 'tcp_insertions', ...) and a
        'session_id' column.

    Raises:
        ValueError: if metrics_list is empty or contains an unsupported metric, or if
            ref_seglst is empty (no session_id can be determined).
    """
    supported_metrics = {'tcp_wer'}
    if metrics_list is None:
        metrics_list = ['tcp_wer']
    if not metrics_list:
        raise ValueError("metrics_list must contain at least one metric")
    unsupported = [m for m in metrics_list if m not in supported_metrics]
    if unsupported:
        raise ValueError(
            f"Unsupported metrics {unsupported}; supported metrics are {sorted(supported_metrics)}")
    if len(ref_seglst) == 0:
        raise ValueError("ref_seglst is empty; cannot determine the session_id")

    session_id = ref_seglst.segments[0]['session_id']
    if len(tcp_hyp_seglst) == 0:
        tcp_hyp_seglst = create_dummy_seg_list(session_id)
        _LOG.warning(f"Empty tcp_wer_hyp_json, using dummy segment: {tcp_hyp_seglst.segments[0]}")

    wers_to_concat = []
    if "tcp_wer" in metrics_list:
        tcp_wer_res = calc_session_tcp_wer(ref_seglst, tcp_hyp_seglst, collar)
        wers_to_concat.append(tcp_wer_res.drop(columns='session_id'))
    wer_df = pd.concat(wers_to_concat, axis=1)
    # NOTE(review): assumes ref_seglst holds a single session; every row is labelled with the
    # first reference segment's session_id.
    wer_df['session_id'] = session_id
    _LOG.debug('Done calculating WER')
    # Only build the (potentially large) table preview when debug logging is actually enabled.
    if _LOG.isEnabledFor(logging.DEBUG):
        _LOG.debug(f"\n{wer_df[['session_id', *metrics_list]]}")
    return wer_df
def aggregate_wer_metrics(wer_df: pd.DataFrame, metrics_list: List[str]) -> Dict:
    """
    Aggregate per-session WER rows into corpus-level metrics.

    All numeric (and boolean) columns of ``wer_df`` are summed. Then, for every metric name
    in ``metrics_list`` (e.g. "tcp_wer", whose prefix is "tcp"):
      * "<prefix>_wer" is recomputed as total errors / total length, i.e. a length-weighted
        WER rather than the mean of the per-session WERs;
      * "<prefix>_missed_speaker", "<prefix>_falarm_speaker" and "<prefix>_scored_speaker",
        when present, are replaced by their per-row means under "<prefix>_mean_<name>".

    NOTE(review): this function is defined twice in this module with the same body; the later
    definition is the one bound at import time.

    Args:
        wer_df: one row per session, with "<prefix>_errors" and "<prefix>_length" columns.
        metrics_list: metric names of the form "<prefix>_<name>".

    Returns:
        OrderedDict mapping column/metric name to its aggregated value. The WER is NaN when
        the total length is 0; the means are NaN when ``wer_df`` has no rows.

    Raises:
        ValueError: if a metric name contains no "_" separator.
        KeyError: if "<prefix>_errors" or "<prefix>_length" is not a numeric column.
    """
    # Public equivalent of the private DataFrame._get_numeric_data(): numeric and bool columns.
    num_wer_df = wer_df.select_dtypes(include=['number', 'bool'])
    metrics = num_wer_df.sum().to_dict(into=OrderedDict)
    n_rows = len(num_wer_df)
    for metric in metrics_list:
        mprefix, sep, _ = metric.partition("_")
        if not sep:
            raise ValueError(f"Metric name {metric!r} must look like '<prefix>_<name>'")
        errors = metrics[mprefix + "_errors"]
        length = metrics[mprefix + "_length"]
        metrics[mprefix + "_wer"] = errors / length if length else float('nan')
        for k in ('missed_speaker', 'falarm_speaker', 'scored_speaker'):
            # Replace the summed speaker counts by their mean over rows (sessions).
            key = f"{mprefix}_{k}"
            if key not in metrics:
                continue
            metrics[f"{mprefix}_mean_{k}"] = metrics[key] / n_rows if n_rows else float('nan')
            del metrics[key]
    return metrics
def normalize_segment(segment: SegLstSegment, tn):
    """
    Run the text normalizer ``tn`` over the segment's 'words' field.

    The segment is updated in place and the same object is returned.
    """
    normalized_words = tn(segment["words"])
    segment["words"] = normalized_words
    return segment
def assign_streams(tcorc_hyp_seglst):
    """
    Redistribute hypothesis segments into streams without internal overlap.

    Speakers are visited in groupby order and each speaker's segments in start-time order.
    Every segment is placed in the lowest-numbered stream that has no segment strictly
    overlapping it (touching boundaries do not count), and its 'speaker' field is overwritten
    with that stream index (an int). As many streams are available as there are distinct
    speakers.

    Returns:
        A SegLST of all segments sorted by 'start_time'.

    Raises:
        ValueError: if a segment overlaps some segment in every available stream.
    """
    by_speaker = tcorc_hyp_seglst.groupby(key='speaker')
    streams = [[] for _ in range(len(by_speaker))]

    def overlaps(candidate, placed):
        return candidate['start_time'] < placed['end_time'] and candidate['end_time'] > placed['start_time']

    for speaker_segments in by_speaker.values():
        for segment in speaker_segments.sorted(key='start_time'):
            target = next(
                (stream_idx for stream_idx, stream in enumerate(streams)
                 if not any(overlaps(segment, placed) for placed in stream)),
                None,
            )
            if target is None:
                raise ValueError('No stream found for segment')
            segment['speaker'] = target
            streams[target].append(segment)

    flattened = [segment for stream in streams for segment in stream]
    return meeteval.io.SegLST(flattened).sorted(key='start_time')
def filter_empty_segments(seg_lst):
    """Return ``seg_lst.filter(...)`` keeping only segments whose 'words' is not the empty string."""
    def has_words(segment):
        return segment['words'] != ''

    return seg_lst.filter(has_words)
def find_first_non_overlapping_segment_streams(per_speaker_groups, per_speaker_vad_masks):
    """
    Find the first pair of speakers whose VAD masks share no active frame.

    Pairs are visited in the key order of ``per_speaker_groups``, each unordered pair once,
    with the earlier speaker first.

    Args:
        per_speaker_groups: mapping keyed by speaker id; only its keys are used.
        per_speaker_vad_masks: mapping speaker id -> boolean numpy mask. Masks that are
            compared should have the same length.

    Returns:
        ``(speaker_id, other_speaker_id)`` for the first disjoint pair, or ``None`` when every
        pair overlaps (or there are fewer than two speakers).
    """
    speaker_ids = list(per_speaker_groups)
    for idx, speaker_id in enumerate(speaker_ids):
        # `&` is symmetric, so each unordered pair only needs to be checked once.
        for other_speaker_id in speaker_ids[idx + 1:]:
            vad_mask_merged = per_speaker_vad_masks[speaker_id] & per_speaker_vad_masks[other_speaker_id]
            if not vad_mask_merged.any():
                return (speaker_id, other_speaker_id)
    return None
def change_speaker_id(segment, speaker_id):
    """
    Overwrite the segment's 'speaker' field with ``speaker_id``.

    The segment is modified in place; the same object is returned so the helper can be used
    as a mapping function.
    """
    updated = segment
    updated['speaker'] = speaker_id
    return updated
def merge_streams(tcorc_hyp_seglst):
    """
    Greedily merge hypothesis speaker streams whose voice activity never overlaps.

    Segments are grouped by 'speaker' and a boolean VAD mask with a 0.01 s step is built for
    each speaker. While some pair of speakers has disjoint masks, the second speaker's
    segments are relabelled with the first speaker's id (via change_speaker_id, which mutates
    the segment dicts it is given) and the two masks are OR-ed together.

    Args:
        tcorc_hyp_seglst: hypothesis SegLST whose segments have 'speaker', 'start_time' and
            'end_time'.

    Returns:
        A SegLST of all (possibly relabelled) segments sorted by 'start_time'. An empty input
        is returned unchanged.
    """
    per_speaker_groups = tcorc_hyp_seglst.groupby(key='speaker')
    if len(per_speaker_groups) == 0:
        # Nothing to merge; returning early also avoids max() over an empty sequence below.
        return tcorc_hyp_seglst
    # Build one VAD mask per speaker.
    per_speaker_vad_masks = {}
    for speaker_id, speaker_seglst in per_speaker_groups.items():
        per_speaker_vad_masks[speaker_id] = create_vad_mask(speaker_seglst, time_step=0.01)
    longest_mask = max(len(mask) for mask in per_speaker_vad_masks.values())
    # Pad all masks to a common length so they can be combined element-wise.
    for speaker_id, mask in per_speaker_vad_masks.items():
        per_speaker_vad_masks[speaker_id] = np.pad(mask, (0, longest_mask - len(mask)))
    # Repeatedly merge the first pair of speakers with non-overlapping masks until none is left.
    while True:
        res = find_first_non_overlapping_segment_streams(per_speaker_groups, per_speaker_vad_masks)
        if res is None:
            break
        speaker_id, other_speaker_id = res
        per_speaker_groups[speaker_id] = per_speaker_groups[speaker_id] + per_speaker_groups[other_speaker_id].map(
            lambda seg: change_speaker_id(seg, speaker_id))
        per_speaker_vad_masks[speaker_id] = per_speaker_vad_masks[speaker_id] | per_speaker_vad_masks[other_speaker_id]
        del per_speaker_groups[other_speaker_id]
        del per_speaker_vad_masks[other_speaker_id]
    tcorc_hyp_seglst = meeteval.io.SegLST(
        [seg for speaker_seglst in per_speaker_groups.values() for seg in speaker_seglst]).sorted(key='start_time')
    return tcorc_hyp_seglst
def normalize_segment(segment: SegLstSegment, tn):
    """
    Apply the text normalizer ``tn`` to ``segment["words"]``.

    The segment is mutated in place, and the same object is returned.
    """
    segment["words"] = tn(segment["words"])
    return segment
def create_vad_mask(segments, time_step=0.1, total_duration=None):
    """
    Create a boolean VAD mask for the given segments.

    Frame ``k`` of the mask corresponds to the time interval starting at ``k * time_step``.
    A segment marks frames ``floor(start_time / time_step)`` up to, but not including,
    ``floor(end_time / time_step)`` as active.

    :param segments: iterable of segments, each containing 'start_time' and 'end_time'
        (any value accepted by ``float()``, e.g. Decimal).
    :param time_step: resolution of the VAD mask in seconds (default: 100ms). Must be positive.
    :param total_duration: optionally specify the total duration covered by the mask.
        If not provided, the maximum end time of the segments is used (0 when there are no
        segments). Segments extending past ``total_duration`` are clipped.
    :return: boolean numpy array of length ``floor(total_duration / time_step) + 1``;
        True marks voice activity, False silence.
    :raises ValueError: if ``time_step`` is not positive.
    """
    if time_step <= 0:
        raise ValueError(f"time_step must be positive, got {time_step}")
    # Materialize once so that generators survive being iterated twice (max() and marking).
    segments = list(segments)

    def to_frame(t):
        # Float division can land just below an exact multiple (0.3 / 0.1 == 2.9999999999999996);
        # the small epsilon keeps such times on the intended frame instead of the previous one.
        return int(float(t) / time_step + 1e-9)

    if total_duration is None:
        total_duration = max((seg['end_time'] for seg in segments), default=0)
    vad_mask = np.zeros(to_frame(total_duration) + 1, dtype=bool)
    for seg in segments:
        vad_mask[to_frame(seg['start_time']):to_frame(seg['end_time'])] = True
    return vad_mask
def find_group_splits(vad, group_duration=30, time_step=0.1):
    """
    Choose frame indices at which a VAD mask can be cut into groups of ~``group_duration``.

    The inactive (False) frames of ``vad`` are scanned in order. A frame becomes a split point
    when it lies at least ``group_duration / time_step`` frames after the previous split (or
    after frame 0 for the first split). Splits therefore always fall on silence. Every group
    except possibly the last spans at least ``group_duration``, and a group can be longer when
    no silent frame is available.

    Args:
        vad: 1-D voice-activity mask (truthy = active), e.g. from create_vad_mask. Any
            array-like is accepted and converted to bool.
        group_duration: minimum group duration, in the same unit as ``time_step``.
        time_step: duration of one VAD frame; should match the step used to build ``vad``.

    Returns:
        List of split frame indices (numpy integers), in increasing order. These are frame
        indices, not times.
    """
    # Convert to bool so that `~` is a logical NOT. On an integer mask `~` is a bitwise NOT
    # (~0 == -1, ~1 == -2), so every frame would look inactive.
    vad = np.asarray(vad, dtype=bool)
    non_active_indices = np.argwhere(~vad).squeeze(axis=-1)
    group_shift = group_duration / time_step
    splits = []
    next_offset = group_shift
    for i in non_active_indices:
        if i >= next_offset:
            splits.append(i)
            next_offset = i + group_shift
    return splits
def map_utterance_to_split(utterance_start_time, splits):
    """
    Return the index of the group that an utterance starting at ``utterance_start_time`` falls into.

    The result is the position of the first split point strictly greater than the start time,
    or ``len(splits)`` when there is none (the utterance belongs to the last group).

    NOTE(review): find_group_splits returns VAD frame indices, not seconds; confirm that
    callers express ``utterance_start_time`` and ``splits`` in the same unit.
    """
    group_idx = next(
        (idx for idx, split_point in enumerate(splits) if utterance_start_time < split_point),
        None,
    )
    if group_idx is None:
        return len(splits)
    return group_idx
def agregate_errors_across_groups(res, session_id):
    """
    Merge per-group ORC error results into a single OrcErrorRate for one session.

    Errors, length, insertions, deletions and substitutions are summed over all groups, and the
    per-group assignments are concatenated (in iteration order of ``res``) into one tuple.
    Self-overlap information is not aggregated and is set to None.

    Args:
        res: mapping of group key -> result object exposing ``errors``, ``length``,
            ``insertions``, ``deletions``, ``substitutions`` and ``assignment``
            (presumably meeteval OrcErrorRate instances — verify against callers).
        session_id: key under which the merged result is returned.

    Returns:
        ``{session_id: OrcErrorRate(...)}`` holding the summed counts.
    """
    groups = list(res.values())
    assignment = []
    for group in groups:
        assignment.extend(group.assignment)
    # No overall error rate is computed here: OrcErrorRate is built from the summed counts only,
    # and computing errors / length eagerly raised ZeroDivisionError for empty or
    # zero-length input.
    merged = OrcErrorRate(errors=sum(group.errors for group in groups),
                          length=sum(group.length for group in groups),
                          insertions=sum(group.insertions for group in groups),
                          deletions=sum(group.deletions for group in groups),
                          substitutions=sum(group.substitutions for group in groups),
                          hypothesis_self_overlap=None,
                          reference_self_overlap=None,
                          assignment=tuple(assignment))
    return {session_id: merged}
def aggregate_wer_metrics(wer_df: pd.DataFrame, metrics_list: List[str]) -> Dict:
    """
    Aggregate per-session WER rows into corpus-level metrics.

    All numeric (and boolean) columns of ``wer_df`` are summed. Then, for every metric name
    in ``metrics_list`` (e.g. "tcp_wer", whose prefix is "tcp"):
      * "<prefix>_wer" is recomputed as total errors / total length, i.e. a length-weighted
        WER rather than the mean of the per-session WERs;
      * "<prefix>_missed_speaker", "<prefix>_falarm_speaker" and "<prefix>_scored_speaker",
        when present, are replaced by their per-row means under "<prefix>_mean_<name>".

    NOTE(review): this function is defined twice in this module with the same body; this later
    definition is the one bound at import time.

    Args:
        wer_df: one row per session, with "<prefix>_errors" and "<prefix>_length" columns.
        metrics_list: metric names of the form "<prefix>_<name>".

    Returns:
        OrderedDict mapping column/metric name to its aggregated value. The WER is NaN when
        the total length is 0; the means are NaN when ``wer_df`` has no rows.

    Raises:
        ValueError: if a metric name contains no "_" separator.
        KeyError: if "<prefix>_errors" or "<prefix>_length" is not a numeric column.
    """
    # Public equivalent of the private DataFrame._get_numeric_data(): numeric and bool columns.
    num_wer_df = wer_df.select_dtypes(include=['number', 'bool'])
    metrics = num_wer_df.sum().to_dict(into=OrderedDict)
    n_rows = len(num_wer_df)
    for metric in metrics_list:
        mprefix, sep, _ = metric.partition("_")
        if not sep:
            raise ValueError(f"Metric name {metric!r} must look like '<prefix>_<name>'")
        errors = metrics[mprefix + "_errors"]
        length = metrics[mprefix + "_length"]
        metrics[mprefix + "_wer"] = errors / length if length else float('nan')
        for k in ('missed_speaker', 'falarm_speaker', 'scored_speaker'):
            # Replace the summed speaker counts by their mean over rows (sessions).
            key = f"{mprefix}_{k}"
            if key not in metrics:
                continue
            metrics[f"{mprefix}_mean_{k}"] = metrics[key] / n_rows if n_rows else float('nan')
            del metrics[key]
    return metrics
|