import os
import sys
import ast
import torch
import itertools
import collections
sys.path.append(os.getcwd())
from main.library.speaker_diarization.speechbrain import if_main_process, ddp_barrier
from main.library.speaker_diarization.features import register_checkpoint_hooks, mark_as_saver, mark_as_loader
DEFAULT_UNK = "<unk>"
DEFAULT_BOS = "<bos>"
DEFAULT_EOS = "<eos>"
DEFAULT_BLANK = "<blank>"
@register_checkpoint_hooks
class CategoricalEncoder:
    VALUE_SEPARATOR = " => "
    EXTRAS_SEPARATOR = "================\n"
    def __init__(self, starting_index=0, **special_labels):
        self.lab2ind = {}
        self.ind2lab = {}
        self.starting_index = starting_index
        self.handle_special_labels(special_labels)
    def handle_special_labels(self, special_labels):
        if "unk_label" in special_labels: self.add_unk(special_labels["unk_label"])
    def __len__(self):
        return len(self.lab2ind)
    @classmethod
    def from_saved(cls, path):
        obj = cls()
        obj.load(path)
        return obj
    def update_from_iterable(self, iterable, sequence_input=False):
        label_iterator = itertools.chain.from_iterable(iterable) if sequence_input else iter(iterable)
        for label in label_iterator:
            self.ensure_label(label)
    def update_from_didataset(self, didataset, output_key, sequence_input=False):
        with didataset.output_keys_as([output_key]):
            self.update_from_iterable((data_point[output_key] for data_point in didataset), sequence_input=sequence_input)
    def limited_labelset_from_iterable(self, iterable, sequence_input=False, n_most_common=None, min_count=1):
        label_iterator = itertools.chain.from_iterable(iterable) if sequence_input else iter(iterable)
        counts = collections.Counter(label_iterator)
        for label, count in counts.most_common(n_most_common):
            if count < min_count: break
            self.add_label(label)
        return counts
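    # Sketch: given the iterable "aaabbc", n_most_common=2 keeps only "a" and
    # "b"; min_count=2 would likewise drop the singleton "c", since
    # most_common() yields labels in descending count order.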
    def load_or_create(self, path, from_iterables=[], from_didatasets=[], sequence_input=False, output_key=None, special_labels={}):
        try:
            if if_main_process():
                if not self.load_if_possible(path):
                    for iterable in from_iterables:
                        self.update_from_iterable(iterable, sequence_input)
                    for didataset in from_didatasets:
                        if output_key is None: raise ValueError("output_key must be given when creating the encoder from didatasets")
                        self.update_from_didataset(didataset, output_key, sequence_input)
                    self.handle_special_labels(special_labels)
                    self.save(path)
        finally:
            ddp_barrier()
            self.load(path)
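    # Usage sketch for distributed runs: only the main process builds and
    # saves the encoder; every rank then waits on ddp_barrier() and reloads
    # the same file, so all ranks end up with identical mappings. For example
    # (path and labels here are arbitrary placeholders):
    #   encoder.load_or_create(path="label_encoder.txt",
    #                          from_iterables=[["spk1", "spk2"]],
    #                          special_labels={"unk_label": DEFAULT_UNK})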
    def add_label(self, label):
        if label in self.lab2ind: raise KeyError(f"Label already present: {label!r}")
        index = self._next_index()
        self.lab2ind[label] = index
        self.ind2lab[index] = label
        return index
    def ensure_label(self, label):
        if label in self.lab2ind: return self.lab2ind[label]
        else: return self.add_label(label)
    def insert_label(self, label, index):
        if label in self.lab2ind: raise KeyError(f"Label already present: {label!r}")
        else: self.enforce_label(label, index)
    def enforce_label(self, label, index):
        index = int(index)
        if label in self.lab2ind:
            if index == self.lab2ind[label]: return
            else: del self.ind2lab[self.lab2ind[label]]
        if index in self.ind2lab:
            saved_label = self.ind2lab[index]
            moving_other = True
        else: moving_other = False
        self.lab2ind[label] = index
        self.ind2lab[index] = label
        if moving_other:
            new_index = self._next_index()
            self.lab2ind[saved_label] = new_index
            self.ind2lab[new_index] = saved_label
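    # Sketch of the eviction behavior above: starting from {"a": 0, "b": 1},
    # enforce_label("c", 0) stores "c" at index 0 and relocates "a" to the
    # next free index, leaving {"c": 0, "b": 1, "a": 2}.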
    def add_unk(self, unk_label=DEFAULT_UNK):
        self.unk_label = unk_label
        return self.add_label(unk_label)
    def _next_index(self):
        index = self.starting_index
        while index in self.ind2lab:
            index += 1
        return index
    def is_continuous(self):
        indices = sorted(self.ind2lab.keys())
        return self.starting_index in indices and all(j - i == 1 for i, j in zip(indices[:-1], indices[1:]))
    def encode_label(self, label, allow_unk=True):
        self._assert_len()
        try:
            return self.lab2ind[label]
        except KeyError:
            # Fall back to the unk label only when one is set and the caller allows it.
            if allow_unk and hasattr(self, "unk_label"): return self.lab2ind[self.unk_label]
            raise KeyError(f"Unknown label {label!r} and cannot use an unk_label fallback")
    def encode_label_torch(self, label, allow_unk=True):
        return torch.LongTensor([self.encode_label(label, allow_unk)])
    def encode_sequence(self, sequence, allow_unk=True):
        self._assert_len()
        return [self.encode_label(label, allow_unk) for label in sequence]
    def encode_sequence_torch(self, sequence, allow_unk=True):
        return torch.LongTensor([self.encode_label(label, allow_unk) for label in sequence])
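    # Round-trip sketch, reusing the "abcd" encoder from the class docstring:
    #   encoded = encoder.encode_sequence_torch("bad")  # tensor([1, 0, 3])
    #   encoder.decode_torch(encoded)                   # ['b', 'a', 'd']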
    def decode_torch(self, x):
        self._assert_len()
        decoded = []
        if x.ndim == 1:
            for element in x:
                decoded.append(self.ind2lab[int(element)])
        else:
            for subtensor in x:
                decoded.append(self.decode_torch(subtensor))
        return decoded
    def decode_ndim(self, x):
        self._assert_len()
        try:
            decoded = []
            for subtensor in x:
                decoded.append(self.decode_ndim(subtensor))
            return decoded
        except TypeError:
            return self.ind2lab[int(x)]
    @mark_as_saver
    def save(self, path):
        self._save_literal(path, self.lab2ind, self._get_extras())
    def load(self, path):
        lab2ind, ind2lab, extras = self._load_literal(path)
        self.lab2ind = lab2ind
        self.ind2lab = ind2lab
        self._set_extras(extras)
    @mark_as_loader
    def load_if_possible(self, path, end_of_epoch=False):
        del end_of_epoch
        try:
            self.load(path)
        except FileNotFoundError:
            return False
        except (ValueError, SyntaxError):
            return False
        return True
    def expect_len(self, expected_len):
        self.expected_len = expected_len
    def ignore_len(self):
        self.expected_len = None
    def _assert_len(self):
        if hasattr(self, "expected_len"):
            if self.expected_len is None: return
            if len(self) != self.expected_len: raise RuntimeError(f"Encoder has {len(self)} labels, expected {self.expected_len}")
        else:
            self.ignore_len()
            return
    def _get_extras(self):
        extras = {"starting_index": self.starting_index}
        if hasattr(self, "unk_label"): extras["unk_label"] = self.unk_label
        return extras
    def _set_extras(self, extras):
        if "unk_label" in extras: self.unk_label = extras["unk_label"]
        self.starting_index = extras["starting_index"]
    @staticmethod
    def _save_literal(path, lab2ind, extras):
        with open(path, "w", encoding="utf-8") as f:
            for label, ind in lab2ind.items():
                f.write(repr(label) + CategoricalEncoder.VALUE_SEPARATOR + str(ind) + "\n")
            f.write(CategoricalEncoder.EXTRAS_SEPARATOR)
            for key, value in extras.items():
                f.write(repr(key) + CategoricalEncoder.VALUE_SEPARATOR + repr(value) + "\n")
            f.flush()
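    # Resulting file layout: one repr'd label per line, then the extras block:
    #   'a' => 0
    #   'b' => 1
    #   ================
    #   'starting_index' => 0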
    @staticmethod
    def _load_literal(path):
        lab2ind, ind2lab, extras = {}, {}, {}
        with open(path, encoding="utf-8") as f:
            for line in f:
                if line == CategoricalEncoder.EXTRAS_SEPARATOR: break
                # The index is the last field on the line, so split from the
                # right in case a label's repr itself contains the separator.
                literal, ind = line.strip().rsplit(CategoricalEncoder.VALUE_SEPARATOR, maxsplit=1)
                ind = int(ind)
                label = ast.literal_eval(literal)
                lab2ind[label] = ind
                # Key ind2lab by the parsed int, not the raw string, so that
                # decode_torch/decode_ndim lookups via int(...) succeed.
                ind2lab[ind] = label
            for line in f:
                literal_key, literal_value = line.strip().split(CategoricalEncoder.VALUE_SEPARATOR, maxsplit=1)
                extras[ast.literal_eval(literal_key)] = ast.literal_eval(literal_value)
        return lab2ind, ind2lab, extras
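if __name__ == "__main__":
    # Smoke-test sketch: build a small encoder with an unk fallback, persist
    # it, and reload it through the from_saved constructor. Illustrative only;
    # the label values and temp-file path are arbitrary.
    import tempfile
    enc = CategoricalEncoder(unk_label=DEFAULT_UNK)
    enc.update_from_iterable(["cat", "dog", "bird"])
    with tempfile.NamedTemporaryFile("w", suffix=".txt", delete=False) as tmp:
        path = tmp.name
    enc.save(path)
    reloaded = CategoricalEncoder.from_saved(path)
    assert reloaded.lab2ind == enc.lab2ind
    print(reloaded.encode_sequence_torch(["dog", "fish"]))  # "fish" maps to <unk>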