|
import random |
|
from collections import Counter |
|
from typing import NoReturn |
|
|
|
from .ner_distribution import NERDistribution |
|
|
|
# Seed the process-wide `random` module at import time so that the weighted
# split choices made in DatasetSplits.get_split are reproducible from run to run.
# NOTE(review): this is an import-time side effect that also reseeds `random`
# for any other code running in the same process.
random.seed(41)
|
|
|
|
|
class DatasetSplits:
    """
    Prepare dataset splits - training, validation & testing splits.

    Compute NER distributions in the dataset. Based on this we assign notes to
    different splits and, at the same time, keep the distribution of NER types
    in each split similar.

    Keep track of the split information - which notes (group keys) are present
    in which split and the label distribution in each split.
    """

    def __init__(
            self,
            ner_distribution: 'NERDistribution',
            train_proportion: int,
            validation_proportion: int,
            test_proportion: int,
            margin: float
    ) -> None:
        """
        Set up the bookkeeping used to assign notes to splits.

        The NER counts of every group in ``ner_distribution`` are summed to get the
        dataset-wide count of each NER type and its share (in percent) of all NER
        mentions. Each split gets an NER budget (``total``) equal to its proportion,
        interpreted as a percentage, of the total NER count. A note is accepted by a
        split only if, for every NER type in the note, the split's count of that type
        as a share of the split budget stays within ``margin`` percentage points of
        the dataset-wide share.

        Args:
            ner_distribution (NERDistribution): The NER distribution in the dataset.
                ``get_ner_distribution()`` must return a mapping of group key to a
                mapping of NER type -> count, and ``get_group_distribution(key=...)``
                must return the NER type -> count mapping of one group.
            train_proportion (int): Percentage of the NER mentions budgeted for the
                train split; also its relative weight when a split is drawn at random
            validation_proportion (int): Same, for the validation split
            test_proportion (int): Same, for the test split
            margin (float): Percentage points by which a split's share of an NER type
                may exceed the dataset-wide share of that type
        """
        self._ner_distribution = ner_distribution

        # Sum the per-group NER counts into dataset-wide counts per NER type
        total_distribution = Counter()
        for counts in ner_distribution.get_ner_distribution().values():
            for label, count in counts.items():
                total_distribution[label] += count

        self._total_ner = sum(total_distribution.values())
        # Dataset-wide share (in percent) of each NER type - the target each split tracks
        self._label_dist_percentages = {
            ner_type: float(count) / self._total_ner * 100 if self._total_ner else 0
            for ner_type, count in total_distribution.items()
        }
        self._margin = margin

        self._splits = ['train', 'validation', 'test']
        # Relative weights used by random.choices when drawing a split in get_split
        self._split_weights = [train_proportion, validation_proportion, test_proportion]
        # Split currently being checked/updated (set through __set_split)
        self._split = None
        # Group key -> split the key was assigned to
        self._processed_keys = {}

        # NER budget of each split: its proportion (as a percentage) of all NER mentions
        train_ner_count = int(train_proportion * self._total_ner / 100)
        validation_ner_count = int(validation_proportion * self._total_ner / 100)
        test_ner_count = int(test_proportion * self._total_ner / 100)

        # Per-split bookkeeping:
        #   remain / remain_weights - the other splits (and their weights) to try, in
        #       this order, when this split rejects a note. The order is kept fixed so
        #       that seeded runs reproduce the same assignment.
        #   total - NER budget of the split (denominator in __check_split)
        #   groups - group keys assigned to the split
        #   number_of_notes - initialised to 0; NOTE(review): not updated in this class
        #   label_dist - NER type counts of the notes assigned so far
        self._splits_info = {'train': {'remain': ['validation', 'test'],
                                       'total': train_ner_count,
                                       'remain_weights': [validation_proportion, test_proportion],
                                       'groups': list(), 'number_of_notes': 0, 'label_dist': Counter()},
                             'validation': {'remain': ['train', 'test'],
                                            'total': validation_ner_count,
                                            'remain_weights': [train_proportion, test_proportion],
                                            'groups': list(), 'number_of_notes': 0, 'label_dist': Counter()},
                             'test': {'remain': ['validation', 'train'],
                                      'total': test_ner_count,
                                      'remain_weights': [validation_proportion, train_proportion],
                                      'groups': list(), 'number_of_notes': 0, 'label_dist': Counter()}}

    def __set_split(self, split: str) -> None:
        """
        Set the split that is currently being checked/processed.

        The subsequent __check_split / __update_* calls operate on this split.

        Args:
            split (str): The split - train, validation or test
        """
        self._split = split

    def __update_label_dist(self, distribution: Counter) -> None:
        """
        Add the NER counts of a newly assigned note to the current split's counts.

        Args:
            distribution (Counter): NER type -> count for the note
        """
        self._splits_info[self._split]['label_dist'].update(distribution)

    def __update_groups(self, note_group_key: str) -> None:
        """
        Record that ``note_group_key`` now belongs to the current split.

        The key is stored both in the key -> split lookup (so repeated requests
        for the same key return the same split) and in the split's group list.

        Args:
            note_group_key (str): Grouping key of the note - e.g. note_id or patient id
        """
        self._processed_keys[note_group_key] = self._split
        self._splits_info[self._split]['groups'].append(note_group_key)

    def __check_split(self, distribution: Counter) -> bool:
        """
        Check whether the current split can take a note without skewing its NER mix.

        For each NER type in the note, the split's count of that type after adding
        the note is expressed as a percentage of the split's NER budget (``total``).
        If any type would exceed its dataset-wide percentage plus ``margin``, the
        note is rejected so that another split can be tried instead.

        Args:
            distribution (Counter): NER type -> count for the note

        Returns:
            (bool): True if the note can be added to the split, False otherwise
                (always False for a split with a zero budget)
        """
        split_label_dist = self._splits_info[self._split]['label_dist']
        split_total = self._splits_info[self._split]['total']
        if split_total == 0:
            return False
        for ner_type, count in distribution.items():
            percentage = (split_label_dist.get(ner_type, 0) + count) / split_total * 100
            # An NER type missing from the dataset-wide distribution has a 0% target
            # (previously this raised KeyError)
            target = self._label_dist_percentages.get(ner_type, 0.0)
            if percentage > target + self._margin:
                return False
        return True

    def get_split(self, key: str) -> str:
        """
        Assign a split to the note group identified by ``key``.

        A split is drawn at random, weighted by the split proportions, and accepted
        if adding the group's NER counts keeps that split's NER mix within the
        margin (see __check_split). On rejection, the remaining splits are tried in
        a fixed order with their weights. If no split accepts the group, a split is
        drawn from the original weights and the group is assigned there regardless.
        A key that was already assigned always gets the same split back.

        Args:
            key (str): The note id or patient id of the note (some grouping key)

        Returns:
            (str): The split
        """
        if key in self._processed_keys:
            return self._processed_keys[key]
        distribution = self._ner_distribution.get_group_distribution(key=key)
        current_splits = self._splits
        current_weights = self._split_weights
        while True:
            check_split = random.choices(current_splits, current_weights)[0]
            self.__set_split(check_split)
            if self.__check_split(distribution=distribution):
                self.__update_groups(key)
                self.__update_label_dist(distribution=distribution)
                return check_split
            remain_weights = self._splits_info[check_split]['remain_weights']
            if len(current_splits) == 3 and sum(remain_weights) > 0:
                # Try the other two splits next
                current_splits = self._splits_info[check_split]['remain']
                current_weights = remain_weights
            elif len(current_splits) == 2 and current_weights[1 - current_splits.index(check_split)] != 0:
                # Only the other split is left to try
                current_splits = [current_splits[1 - current_splits.index(check_split)]]
                current_weights = [100]
            else:
                # Nothing (with a non-zero weight) is left to try. Previously an
                # all-zero remain_weights reached random.choices and raised ValueError.
                break
        # No split accepts the group within the margin: fall back to a weighted random pick
        check_split = random.choices(self._splits, self._split_weights)[0]
        self.__set_split(check_split)
        self.__update_groups(key)
        self.__update_label_dist(distribution=distribution)
        return check_split
|
|