import os
import sys
import ast
import torch
import itertools
import collections

sys.path.append(os.getcwd())

from main.library.speaker_diarization.speechbrain import if_main_process, ddp_barrier
from main.library.speaker_diarization.features import register_checkpoint_hooks, mark_as_saver, mark_as_loader

DEFAULT_UNK = "<unk>"
DEFAULT_BOS = "<bos>"
DEFAULT_EOS = "<eos>"
DEFAULT_BLANK = "<blank>"

@register_checkpoint_hooks
class CategoricalEncoder:
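    """Encode a discrete set of labels to integer indices (and back), with optional special labels such as <unk>."""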
    VALUE_SEPARATOR = " => "
    EXTRAS_SEPARATOR = "================\n"

    def __init__(self, starting_index=0, **special_labels):
        self.lab2ind = {}
        self.ind2lab = {}
        self.starting_index = starting_index
        self.handle_special_labels(special_labels)

    def handle_special_labels(self, special_labels):
        if "unk_label" in special_labels: self.add_unk(special_labels["unk_label"])

    def __len__(self):
        return len(self.lab2ind)

    @classmethod
    def from_saved(cls, path):
        obj = cls()
        obj.load(path)
        return obj

    def update_from_iterable(self, iterable, sequence_input=False):
        label_iterator = itertools.chain.from_iterable(iterable) if sequence_input else iter(iterable)
        for label in label_iterator:
            self.ensure_label(label)

    def update_from_didataset(self, didataset, output_key, sequence_input=False):
        with didataset.output_keys_as([output_key]):
            self.update_from_iterable((data_point[output_key] for data_point in didataset), sequence_input=sequence_input)

    def limited_labelset_from_iterable(self, iterable, sequence_input=False, n_most_common=None, min_count=1):
        label_iterator = itertools.chain.from_iterable(iterable) if sequence_input else iter(iterable)
        counts = collections.Counter(label_iterator)

        for label, count in counts.most_common(n_most_common):
            if count < min_count: break
            self.add_label(label)

        return counts

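    # In distributed (DDP) runs, only the main process builds and saves the
    # encoder; every process then waits at the barrier and loads the same file,
    # so all ranks end up with an identical label mapping.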
    def load_or_create(self, path, from_iterables=[], from_didatasets=[], sequence_input=False, output_key=None, special_labels={}):
        try:
            if if_main_process():
                if not self.load_if_possible(path):
                    for iterable in from_iterables:
                        self.update_from_iterable(iterable, sequence_input)

                    for didataset in from_didatasets:
                        if output_key is None: raise ValueError("output_key must be given when updating from didatasets")
                        self.update_from_didataset(didataset, output_key, sequence_input)

                    self.handle_special_labels(special_labels)
                    self.save(path)
        finally:
            ddp_barrier()
            self.load(path)

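    # add_label refuses to overwrite an existing mapping; use ensure_label for
    # idempotent insertion, or insert_label/enforce_label to pin a specific index.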
    def add_label(self, label):
        if label in self.lab2ind: raise KeyError(f"Label already present: {label}")
        index = self._next_index()

        self.lab2ind[label] = index
        self.ind2lab[index] = label

        return index

    def ensure_label(self, label):
        if label in self.lab2ind: return self.lab2ind[label]
        else: return self.add_label(label)

    def insert_label(self, label, index):
        if label in self.lab2ind: raise KeyError(f"Label already present: {label}")
        else: self.enforce_label(label, index)

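    # enforce_label pins `label` to `index` unconditionally; a label that
    # previously occupied that index is moved to the next free index.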
    def enforce_label(self, label, index):
        index = int(index)

        if label in self.lab2ind:
            if index == self.lab2ind[label]: return
            else: del self.ind2lab[self.lab2ind[label]]

        if index in self.ind2lab:
            saved_label = self.ind2lab[index]
            moving_other = True
        else: moving_other = False

        self.lab2ind[label] = index
        self.ind2lab[index] = label

        if moving_other:
            new_index = self._next_index()
            self.lab2ind[saved_label] = new_index
            self.ind2lab[new_index] = saved_label

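    # add_unk must be called (or unk_label passed to __init__) before encoding
    # out-of-vocabulary labels with allow_unk=True.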
    def add_unk(self, unk_label=DEFAULT_UNK):
        self.unk_label = unk_label
        return self.add_label(unk_label)

    def _next_index(self):
        index = self.starting_index
        while index in self.ind2lab:
            index += 1

        return index

    def is_continuous(self):
        indices = sorted(self.ind2lab.keys())
        return self.starting_index in indices and all(j - i == 1 for i, j in zip(indices[:-1], indices[1:]))

    def encode_label(self, label, allow_unk=True):
        self._assert_len()

        try:
            return self.lab2ind[label]
        except KeyError:
            # Fall back to the unk index only when allowed and available.
            if allow_unk and hasattr(self, "unk_label"): return self.lab2ind[self.unk_label]
            raise KeyError(f"Unknown label: {label} (no unk fallback available or allow_unk=False)")

    def encode_label_torch(self, label, allow_unk=True):
        return torch.LongTensor([self.encode_label(label, allow_unk)])

    def encode_sequence(self, sequence, allow_unk=True):
        self._assert_len()
        return [self.encode_label(label, allow_unk) for label in sequence]

    def encode_sequence_torch(self, sequence, allow_unk=True):
        return torch.LongTensor([self.encode_label(label, allow_unk) for label in sequence])

    def decode_torch(self, x):
        self._assert_len()
        decoded = []

        if x.ndim == 1:
            for element in x:
                decoded.append(self.ind2lab[int(element)])
        else:
            for subtensor in x:
                decoded.append(self.decode_torch(subtensor))

        return decoded

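    # decode_ndim handles arbitrarily nested sequences or tensors: it recurses
    # until iteration fails with TypeError, then treats the element as an index.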
    def decode_ndim(self, x):
        self._assert_len()
        try:
            decoded = []
            for subtensor in x:
                decoded.append(self.decode_ndim(subtensor))

            return decoded
        except TypeError:
            return self.ind2lab[int(x)]

    @mark_as_saver
    def save(self, path):
        self._save_literal(path, self.lab2ind, self._get_extras())

    def load(self, path):
        lab2ind, ind2lab, extras = self._load_literal(path)
        self.lab2ind = lab2ind
        self.ind2lab = ind2lab
        self._set_extras(extras)

    @mark_as_loader
    def load_if_possible(self, path, end_of_epoch=False):
        del end_of_epoch

        try:
            self.load(path)
        except (FileNotFoundError, ValueError, SyntaxError):
            return False

        return True

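    # Optional sanity check: expect_len records the vocabulary size expected at
    # encode/decode time; _assert_len raises if the encoder has since changed.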
    def expect_len(self, expected_len):
        self.expected_len = expected_len

    def ignore_len(self):
        self.expected_len = None

    def _assert_len(self):
        if hasattr(self, "expected_len"):
            if self.expected_len is None: return
            if len(self) != self.expected_len: raise RuntimeError(f"Encoder length mismatch: {len(self)} labels, expected {self.expected_len}")
        else:
            self.ignore_len()
            return

    def _get_extras(self):
        extras = {"starting_index": self.starting_index}
        if hasattr(self, "unk_label"): extras["unk_label"] = self.unk_label

        return extras

    def _set_extras(self, extras):
        if "unk_label" in extras: self.unk_label = extras["unk_label"]
        self.starting_index = extras["starting_index"]

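    # On-disk format: one "repr(label) => index" line per label, then the
    # EXTRAS_SEPARATOR line, then "repr(key) => repr(value)" lines for extras.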
    @staticmethod
    def _save_literal(path, lab2ind, extras):
        with open(path, "w", encoding="utf-8") as f:
            for label, ind in lab2ind.items():
                f.write(repr(label) + CategoricalEncoder.VALUE_SEPARATOR + str(ind) + "\n")

            f.write(CategoricalEncoder.EXTRAS_SEPARATOR)

            for key, value in extras.items():
                f.write(repr(key) + CategoricalEncoder.VALUE_SEPARATOR + repr(value) + "\n")

            f.flush()

    @staticmethod
    def _load_literal(path):
        lab2ind, ind2lab, extras = {}, {}, {}

        with open(path, encoding="utf-8") as f:
            for line in f:
                if line == CategoricalEncoder.EXTRAS_SEPARATOR: break
                literal, ind = line.strip().split(CategoricalEncoder.VALUE_SEPARATOR, maxsplit=1)
                label = ast.literal_eval(literal)
                lab2ind[label] = int(ind)
                ind2lab[int(ind)] = label  # keys must be ints so decode/_next_index lookups work

            for line in f:
                literal_key, literal_value = line.strip().split(CategoricalEncoder.VALUE_SEPARATOR, maxsplit=1)
                extras[ast.literal_eval(literal_key)] = ast.literal_eval(literal_value)

        return lab2ind, ind2lab, extras
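

# Minimal usage sketch (illustrative only, not part of the original module):
# build an encoder from an iterable, add an <unk> fallback, and round-trip a
# sequence through encode/decode. The speaker labels here are made up.
if __name__ == "__main__":
    encoder = CategoricalEncoder()
    encoder.update_from_iterable(["spk1", "spk2", "spk1", "spk3"])
    encoder.add_unk()  # enables the allow_unk fallback for unseen labels

    ids = encoder.encode_sequence_torch(["spk1", "spk3", "never-seen"])
    print(ids)                        # tensor of indices; "never-seen" maps to the <unk> index
    print(encoder.decode_torch(ids))  # ['spk1', 'spk3', '<unk>']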