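"""
Split a JSONL dataset of annotated notes into train, validation and test
sets while keeping the NER type distribution in each split close to the
requested proportions.

Example invocation (a sketch: the module path and file names below are
illustrative placeholders, adjust them to where this file lives in your
package and to your data):

    python -m your_package.dataset_splitter \
        --input_file notes.jsonl \
        --train_file train.jsonl \
        --validation_file validation.jsonl \
        --test_file test.jsonl \
        --print_dist
"""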
import json
import random
from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter
from collections import Counter
from typing import List
from .distribution import NERDistribution, DatasetSplits, PrintDistribution
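# Seed the module-level RNG so that any randomized choices made while
# assigning notes to splits are reproducible. Note: this file does not call
# `random` directly, so the seed is presumably consumed by the imported
# distribution utilities.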
random.seed(41)
class DatasetSplitter:
"""
Prepare dataset splits - training, validation & testing splits
Compute ner distributions in our dataset. Compute ner distributions
based on which we create and store a dictionary which will contain
information about which notes (in a dataset) belong to which split.
Based on this distribution and whether we want to keep certain notes
grouped (e.g by patient) we assign notes to a split, such that the
final ner type distribution in each split is similar.
"""
def __init__(
self,
train_proportion: int = 70,
validation_proportion: int = 15,
test_proportion: int = 15
    ) -> None:
"""
Initialize the proportions of the splits.
Args:
train_proportion (int): Ratio of train dataset
validation_proportion (int): Ratio of validation dataset
test_proportion (int): Ratio of test dataset
"""
self._train_proportion = train_proportion
self._validation_proportion = validation_proportion
self._test_proportion = test_proportion
self._split = None
self._lookup_split = dict()
    def get_split(self, split: str) -> List[str]:
        """
        Return the keys (e.g. note_ids) assigned to the given split.
        Args:
            split (str): The split - train, validation or test
        Returns:
            (List[str]): The keys belonging to the split
        """
        return list(self._lookup_split[split].keys())
    def set_split(self, split: str) -> None:
        """
        Set the split that is currently being checked/processed.
        Subsequent lookups and updates operate on this split's hash map,
        whose keys are the note_ids/patient ids that belong to the split
        and which can be used to check whether a particular note belongs
        to the split.
        Args:
            split (str): The split - train, validation or test
        """
if split not in ['train', 'validation', 'test']:
raise ValueError('Invalid split')
self._split = split
    def __update_split(self, key: str) -> None:
        """
        Add a key (e.g. note_id) to the hash map of the current split.
        This map can then be used to check whether a particular note
        belongs to the split.
        Args:
            key (str): The key that identifies the note belonging to the split
        """
        self._lookup_split[self._split][key] = 1
    def check_note(self, key: str) -> bool:
        """
        Use the hash map built by assign_splits to check whether the note
        identified by key belongs to the currently set split (train,
        validation or test).
        Args:
            key (str): The key that identifies the note
        Returns:
            (bool): True if the note belongs to the split, False otherwise
        """
        if self._split is None:
            raise ValueError('Split not set')
        return key in self._lookup_split[self._split]
    def assign_splits(
        self,
        input_file: str,
        spans_key: str = 'spans',
        metadata_key: str = 'meta',
        group_key: str = 'note_id',
        margin: float = 0.3
    ) -> None:
        """
        Assign notes to the training, validation and test splits based on
        the NER distribution and on whether certain notes should stay
        grouped (e.g. by patient). The assignments are stored in the
        internal lookup maps, which can then be used (via set_split and
        check_note) to filter notes by split.
        Args:
            input_file (str): The input file (JSONL, one note per line)
            spans_key (str): The key where the note spans are present
            metadata_key (str): The key where the note metadata is present
            group_key (str): The key of the note group (e.g. note_id or patient id).
                Notes are grouped by this field, and all notes belonging
                to the same group end up in the same split
            margin (float): Margin of error when maintaining proportions in the splits
        """
        # Compute the distribution of NER types in the grouped notes,
        # for example the distribution of NER types across all notes
        # belonging to a particular patient
        self._lookup_split = {
            'train': dict(),
            'validation': dict(),
            'test': dict()
        }
        ner_distribution = NERDistribution()
        with open(input_file, 'r') as file:
            for line in file:
                note = json.loads(line)
                key = note[metadata_key][group_key]
                ner_distribution.update_distribution(spans=note[spans_key], key=key)
# Initialize the dataset splits object
dataset_splits = DatasetSplits(
ner_distribution=ner_distribution,
train_proportion=self._train_proportion,
validation_proportion=self._validation_proportion,
test_proportion=self._test_proportion,
margin=margin
)
        # Look up the split assigned to each group key and record it in
        # the lookup map for that split
        with open(input_file, 'r') as file:
            for line in file:
                note = json.loads(line)
                key = note[metadata_key][group_key]
                split = dataset_splits.get_split(key=key)
                self.set_split(split)
                self.__update_split(key)
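
# A minimal usage sketch of DatasetSplitter outside the CLI (the file name
# and proportions are illustrative):
#
#     splitter = DatasetSplitter(train_proportion=80, validation_proportion=10, test_proportion=10)
#     splitter.assign_splits(input_file='notes.jsonl', group_key='note_id')
#     splitter.set_split('train')
#     train_keys = splitter.get_split('train')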
def main() -> None:
    """
    Prepare dataset splits: training, validation and test.
    Compute the NER distribution in the dataset. Based on this distribution,
    and on whether certain notes should stay grouped (e.g. by patient),
    assign notes to a split such that the final NER type distribution
    in each split is similar.
    """
    # The following code sets up the arguments to be passed via the CLI
cli_parser = ArgumentParser(
description='configuration arguments provided at run time from the CLI',
formatter_class=ArgumentDefaultsHelpFormatter
)
cli_parser.add_argument(
'--input_file',
type=str,
required=True,
        help='the jsonl file that contains the notes'
)
cli_parser.add_argument(
'--spans_key',
type=str,
default='spans',
        help='the key where the note spans are present in the json object'
)
cli_parser.add_argument(
'--metadata_key',
type=str,
default='meta',
help='the key where the note metadata is present in the json object'
)
cli_parser.add_argument(
'--group_key',
type=str,
default='note_id',
help='the key to group notes by in the json object'
)
cli_parser.add_argument(
'--train_proportion',
type=int,
default=70,
        help='percentage of the dataset to use for training'
)
cli_parser.add_argument(
'--train_file',
type=str,
default=None,
        help='the file to store the train data'
)
cli_parser.add_argument(
'--validation_proportion',
type=int,
default=15,
        help='percentage of the dataset to use for validation'
)
cli_parser.add_argument(
'--validation_file',
type=str,
default=None,
        help='the file to store the validation data'
)
cli_parser.add_argument(
'--test_proportion',
type=int,
default=15,
        help='percentage of the dataset to use for testing'
)
cli_parser.add_argument(
'--test_file',
type=str,
default=None,
        help='the file to store the test data'
)
cli_parser.add_argument(
'--margin',
type=float,
default=0.3,
help='margin of error when maintaining proportions in the splits'
)
cli_parser.add_argument(
'--print_dist',
action='store_true',
help='whether to print the label distribution in the splits'
)
args = cli_parser.parse_args()
dataset_splitter = DatasetSplitter(
train_proportion=args.train_proportion,
validation_proportion=args.validation_proportion,
test_proportion=args.test_proportion
)
dataset_splitter.assign_splits(
input_file=args.input_file,
spans_key=args.spans_key,
metadata_key=args.metadata_key,
group_key=args.group_key,
margin=args.margin
)
    # Write each split to its output file: set the split once, then stream
    # the input file and keep only the notes whose group key belongs to it
    splits = [
        ('train', args.train_proportion, args.train_file),
        ('validation', args.validation_proportion, args.validation_file),
        ('test', args.test_proportion, args.test_file)
    ]
    for split, proportion, output_file in splits:
        if proportion <= 0:
            continue
        dataset_splitter.set_split(split)
        with open(output_file, 'w') as out_file, open(args.input_file, 'r') as in_file:
            for line in in_file:
                note = json.loads(line)
                key = note[args.metadata_key][args.group_key]
                if dataset_splitter.check_note(key):
                    out_file.write(json.dumps(note) + '\n')
    if args.print_dist:
        # Re-read the input file and compute the NER type distribution
        key_counts = Counter()
        ner_distribution = NERDistribution()
        with open(args.input_file, 'r') as file:
            for line in file:
                note = json.loads(line)
                key = note[args.metadata_key][args.group_key]
                key_counts[key] += 1
                ner_distribution.update_distribution(spans=note[args.spans_key], key=key)
print_distribution = PrintDistribution(ner_distribution=ner_distribution, key_counts=key_counts)
train_splits = dataset_splitter.get_split('train')
validation_splits = dataset_splitter.get_split('validation')
test_splits = dataset_splitter.get_split('test')
all_splits = train_splits + validation_splits + test_splits
# Print distribution for each split
print_distribution.split_distribution(split='total', split_info=all_splits)
print_distribution.split_distribution(split='train', split_info=train_splits)
print_distribution.split_distribution(split='validation', split_info=validation_splits)
print_distribution.split_distribution(split='test', split_info=test_splits)
if __name__ == "__main__":
main()