diff --git a/deid/__init__.py b/deid/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..cb8ec063f819a03b114db2c6261aa8ae07e17774 --- /dev/null +++ b/deid/__init__.py @@ -0,0 +1,2 @@ +from .text_deid import TextDeid +__all__ = ["TextDeid"] diff --git a/deid/text_deid.py b/deid/text_deid.py new file mode 100644 index 0000000000000000000000000000000000000000..84d2854bcbaa7f65d92ea0202723003890b06fca --- /dev/null +++ b/deid/text_deid.py @@ -0,0 +1,307 @@ +import json +import re +from argparse import ArgumentParser +from typing import Sequence, List, Tuple, Mapping, Union, Any, Type + +import regex +from seqeval.scheme import IOB1, IOB2, IOBES, BILOU, Entities + +from .utils import remove, replace_tag_type, replace_informative + + +class TextDeid(object): + + def __init__(self, notation, span_constraint): + self._span_constraint = span_constraint + if self._span_constraint == 'strict': + self._scheme = TextDeid.__get_scheme('IO') + elif self._span_constraint == 'super_strict': + self._scheme = TextDeid.__get_scheme('IO') + else: + self._scheme = TextDeid.__get_scheme(notation) + + def decode(self, tokens, predictions): + if self._span_constraint == 'exact': + return predictions + elif self._span_constraint == 'strict': + return TextDeid.__get_relaxed_predictions(predictions) + elif self._span_constraint == 'super_strict': + return TextDeid.__get_super_relaxed_predictions(tokens, predictions) + + def get_predicted_entities_positions( + self, + tokens: Sequence[Mapping[str, Union[str, int]]], + predictions: List[str], + suffix: bool + ) -> List[List[Union[Tuple[Union[str, int], Union[str, int]], Any]]]: + """ + Use the seqeval get_entities method, which goes through the predictions and returns + where the span starts and ends. - [O, O, B-AGE, I-AGE, O, O] this will return + spans starts at token 2 and ends at token 3 - with type AGE. We then extract the + position of the token in the note (character position) - so we return that + this span starts at 32 and ends at 37. The function then returns a nested list + that contains a tuple of tag type and tag position (character positions). 
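+        For illustration, the underlying seqeval call behaves roughly as follows
+        (a minimal sketch - the scheme and the prediction list are made up):
+            >>> from seqeval.scheme import IOB2, Entities
+            >>> ents = Entities(sequences=[['O', 'O', 'B-AGE', 'I-AGE', 'O', 'O']], scheme=IOB2, suffix=False)
+            >>> [(entity.start, entity.end, entity.tag) for entity in ents.entities[0]]
+            [(2, 4, 'AGE')]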
+ Example: [[(3, 9), LOC], [(34, 41), PATIENT], ...]] + Args: + tokens (Sequence[Mapping[str, Union[str, int]]]): The list of tokens in the note + predictions (Sequence[str]): The list of predictions for the note + suffix (str): Whether the B, I etc is in the prefix or the suffix + Returns: + positions_info (List[Tuple[Tuple[int, int], str]])): List containing tuples of tag positions and tag type + """ + positions_info = list() + entities = Entities(sequences=[predictions], scheme=self._scheme, suffix=suffix) + for entity_list in entities.entities: + for entity in entity_list: + position = (tokens[entity.start]['start'], tokens[entity.end - 1]['end']) + positions_info.append([position, entity.tag]) + return positions_info + + def run_deid( + self, + input_file, + predictions_file, + deid_strategy, + keep_age: bool = False, + metadata_key: str = 'meta', + note_id_key: str = 'note_id', + tokens_key: str = 'tokens', + predictions_key: str = 'predictions', + text_key: str = 'text' + ): + # Store note_id to note mapping + note_map = dict() + for line in open(input_file, 'r'): + note = json.loads(line) + note_id = note[metadata_key][note_id_key] + note_map[note_id] = note + # Go through note predictions and de identify the note accordingly + for line in open(predictions_file, 'r'): + note = json.loads(line) + # Get the text using the note_id for this note from the note_map dict + note_id = note[note_id_key] + # Get the note from the note_map dict + deid_note = note_map[note_id] + # Get predictions + predictions = self.decode(tokens=note[tokens_key], predictions=note[predictions_key]) + # Get entities and their positions + entity_positions = self.get_predicted_entities_positions( + tokens=note[tokens_key], + predictions=predictions, + suffix=False + ) + yield TextDeid.__get_deid_text( + deid_note=deid_note, + entity_positions=entity_positions, + deid_strategy=deid_strategy, + keep_age=keep_age, + text_key=text_key + ) + + @staticmethod + def __get_deid_text( + deid_note, + entity_positions, + deid_strategy, + keep_age: bool = False, + text_key: str = 'text' + ): + tag_mapping = TextDeid.__get_tag_mapping(deid_strategy=deid_strategy) + age_pattern = '((? 
Union[Type[IOB2], Type[IOBES], Type[BILOU], Type[IOB1]]:
+        """
+        Get the seqeval scheme based on the notation
+        Args:
+            notation (str): The NER notation
+        Returns:
+            (Union[IOB2, IOBES, BILOU, IOB1]): The seqeval scheme
+        """
+        if notation == 'BIO':
+            return IOB2
+        elif notation == 'BIOES':
+            return IOBES
+        elif notation == 'BILOU':
+            return BILOU
+        elif notation == 'IO':
+            return IOB1
+        else:
+            raise ValueError('Invalid Notation')
+
+
+def main():
+    # The following code sets up the arguments to be passed via CLI or via a JSON file
+    cli_parser = ArgumentParser(description='configuration arguments provided at run time from the CLI')
+    cli_parser.add_argument(
+        '--input_file',
+        type=str,
+        required=True,
+        help='the jsonl file that contains the notes'
+    )
+    cli_parser.add_argument(
+        '--predictions_file',
+        type=str,
+        required=True,
+        help='the jsonl file that contains the predictions'
+    )
+    cli_parser.add_argument(
+        '--span_constraint',
+        type=str,
+        required=True,
+        choices=['exact', 'strict', 'super_strict'],
+        help='whether we want to modify the predictions to make the process of removing phi more strict - '
+             'exact uses the predicted spans as-is, strict and super_strict relax them'
+    )
+    cli_parser.add_argument(
+        '--notation',
+        type=str,
+        required=True,
+        help='the NER notation in the predictions'
+    )
+    cli_parser.add_argument(
+        '--deid_strategy',
+        type=str,
+        required=True,
+        choices=['remove', 'replace_tag_type', 'replace_informative'],
+        help='the strategy used to transform the predicted PHI spans'
+    )
+    cli_parser.add_argument(
+        '--keep_age',
+        action='store_true',
+        help='whether to keep ages below 89'
+    )
+    cli_parser.add_argument(
+        '--text_key',
+        type=str,
+        default='text',
+        help='the key where the note text is present in the json object'
+    )
+    cli_parser.add_argument(
+        '--metadata_key',
+        type=str,
+        default='meta',
+        help='the key where the note metadata is present in the json object'
+    )
+    cli_parser.add_argument(
+        '--note_id_key',
+        type=str,
+        default='note_id',
+        help='the key where the note id is present in the json object'
+    )
+    cli_parser.add_argument(
+        '--tokens_key',
+        type=str,
+        default='tokens',
+        help='the key where the tokens for the notes are present in the json object'
+    )
+    cli_parser.add_argument(
+        '--predictions_key',
+        type=str,
+        default='predictions',
+        help='the key where the note predictions are present in the json object'
+    )
+    cli_parser.add_argument(
+        '--output_file',
+        type=str,
+        required=True,
+        help='the file where the de-identified notes will be written'
+    )
+    # Parse args
+    args = cli_parser.parse_args()
+    text_deid = TextDeid(notation=args.notation, span_constraint=args.span_constraint)
+    deid_notes = text_deid.run_deid(
+        input_file=args.input_file,
+        predictions_file=args.predictions_file,
+        deid_strategy=args.deid_strategy,
+        keep_age=args.keep_age,
+        metadata_key=args.metadata_key,
+        note_id_key=args.note_id_key,
+        tokens_key=args.tokens_key,
+        predictions_key=args.predictions_key,
+        text_key=args.text_key
+    )
+    # Write the dataset to the output file
+    with open(args.output_file, 'w') as file:
+        for deid_note in deid_notes:
+            file.write(json.dumps(deid_note) + '\n')
+
+
+if __name__ == "__main__":
+    # Get deid notes
+    main()
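+# An illustrative invocation of this script (the file names below are hypothetical):
+#   python -m deid.text_deid --input_file notes.jsonl --predictions_file predictions.jsonl \
+#       --span_constraint strict --notation BIO --deid_strategy replace_informative \
+#       --output_file deid_notes.jsonl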
diff --git a/deid/utils.py b/deid/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..9becc49777ec4ecbee35bb8756615f3847815628
--- /dev/null
+++ b/deid/utils.py
@@ -0,0 +1,43 @@
+def remove():
+    return {'PATIENT': '',
+            'STAFF': '',
+            'AGE': '',
+            'DATE': '',
+            'PHONE': '',
+            'MRN': '',
+            'ID': '',
+            'EMAIL': '',
+            'PATORG': '',
+            'LOC': '',
+            'HOSP': '',
+            'OTHERPHI': ''}
+
+
+def replace_tag_type():
+    return {'PATIENT': 'PATIENT',
+            'STAFF': 'STAFF',
+            'AGE': 'AGE',
+            'DATE': 'DATE',
+            'PHONE': 'PHONE',
+            'MRN': 'MRN',
+            'ID': 'ID',
+            'EMAIL': 'EMAIL',
+            'PATORG': 'PATORG',
+            'LOC': 'LOCATION',
+            'HOSP': 'HOSPITAL',
+            'OTHERPHI': 'OTHERPHI'}
+
+
+def replace_informative():
+    return {'PATIENT': '<<PATIENT>>',
+            'STAFF': '<<STAFF>>',
+            'AGE': '<<AGE>>',
+            'DATE': '<<DATE>>',
+            'PHONE': '<<PHONE>>',
+            'MRN': '<<MRN>>',
+            'ID': '<<ID>>',
+            'EMAIL': '<<EMAIL>>',
+            'PATORG': '<<PATORG>>',
+            'LOC': '<<LOCATION>>',
+            'HOSP': '<<HOSPITAL>>',
+            'OTHERPHI': '<<OTHERPHI>>'}
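+
+# For illustration (hypothetical note text - the actual substitution happens in
+# TextDeid): for a HOSP span covering 'MGH' in 'Went to MGH', the three mappings
+# would yield roughly:
+#   remove()              -> 'Went to '
+#   replace_tag_type()    -> 'Went to HOSPITAL'
+#   replace_informative() -> 'Went to <<HOSPITAL>>'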
diff --git a/ner_datasets/__init__.py b/ner_datasets/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e0f1ba62b094aeaaeb0cdafec73d260e904098b4
--- /dev/null
+++ b/ner_datasets/__init__.py
@@ -0,0 +1,5 @@
+from ehr_deidentification.sequence_tagging.dataset_builder.ner_labels import NERLabels
+from .span_fixer import SpanFixer
+from .dataset_splitter import DatasetSplitter
+from .dataset_creator import DatasetCreator
+__all__ = ["NERLabels", "SpanFixer", "DatasetSplitter", "DatasetCreator"]
diff --git a/ner_datasets/__pycache__/__init__.cpython-37.pyc b/ner_datasets/__pycache__/__init__.cpython-37.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..44d90674cc7af86e9b992b2390051294fbcfd92d
Binary files /dev/null and b/ner_datasets/__pycache__/__init__.cpython-37.pyc differ
diff --git a/ner_datasets/dataset_builder/__init__.py b/ner_datasets/dataset_builder/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b419eb6cad93e5db880ff0a9ca1152dda6e21946
--- /dev/null
+++ b/ner_datasets/dataset_builder/__init__.py
@@ -0,0 +1,3 @@
+from .dataset import Dataset
+from .sentence_dataset import SentenceDataset
+__all__ = ["SentenceDataset", "Dataset"]
diff --git a/ner_datasets/dataset_builder/dataset.py b/ner_datasets/dataset_builder/dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..05c519eaeb7f3dd6fa027a66d08bbc0aff7e729e
--- /dev/null
+++ b/ner_datasets/dataset_builder/dataset.py
@@ -0,0 +1,119 @@
+import random
+import re
+from typing import Iterable, Dict, Sequence, Union, Mapping, Optional, List
+
+from .labels import NERTokenLabels, NERPredictTokenLabels, MismatchError
+
+random.seed(41)
+
+
+class Dataset(object):
+    """
+    Build a NER token classification dataset. Each token should have a corresponding label
+    based on the annotated spans.
+    For training we will build the dataset using the annotated spans (e.g from prodigy)
+    For predictions we will assign default labels, to keep the format of the dataset the same.
+    The dataset is on a sentence level, i.e each note is split into sentences and the
+    task is run on a sentence level. Even the predictions are run on a sentence level
+    The dataset would be something like:
+    Tokens: [tok1, tok2, ... tok n]
+    Labels: [lab1, lab2, ... lab n]
+    For the prediction mode the labels would be: [default, default, default .... default]
+    This script can also be used for predictions, the Labels will be filled with some
+    default value. This is done so that we can use the same script for building a dataset to train a model
+    and a dataset to obtain predictions using a model
+    """
+
+    def __init__(
+            self,
+            sentencizer,
+            tokenizer
+    ):
+        """
+        Build a NER token classification dataset
+        For training we will build the dataset using the annotated spans (e.g from prodigy)
+        For predictions we will assign default labels.
+        The dataset is on a sentence level, i.e each note is split into sentences and the de-id
+        task is run on a sentence level. Even the predictions are run on a sentence level
+        The dataset would be something like:
+        Tokens: [tok1, tok2, ... tok n]
+        Labels: [lab1, lab2, ... lab n]
+        This script can also be used for predictions, the Labels will be filled with some
+        default value. This is done so that we can use the same script for building a dataset to train a model
+        and a dataset to obtain predictions using a model
+        Args:
+            sentencizer (Union[SpacySentencizer, MimicStanzaSentencizer, NoteSentencizer]): The sentencizer to use
+                for splitting notes into sentences
+            tokenizer (Union[ClinicalSpacyTokenizer, SpacyTokenizer, CoreNLPTokenizer]): The tokenizer to use for
+                splitting text into tokens
+        """
+        self._sentencizer = sentencizer
+        self._tokenizer = tokenizer
+
+    def get_tokens(
+            self,
+            text: str,
+            spans: Optional[List[Mapping[str, Union[str, int]]]] = None,
+            notation: str = 'BIO',
+            token_text_key: str = 'text',
+            label_key: str = 'label'
+    ) -> Iterable[Sequence[Dict[str, Union[str, int]]]]:
+        """
+        Get a nested list of tokens where the inner list represents the tokens in the
+        sentence and the outer list will contain all the sentences in the note
+        Args:
+            text (str): The text present in the note
+            spans (Optional[List[Mapping[str, Union[str, int]]]]): The NER spans in the note. This will be None if
+                building the dataset for prediction
+            notation (str): The notation we will be using for the label scheme (e.g BIO, BILOU etc)
+            token_text_key (str): The key where the note text is present
+            label_key (str): The key where the note label for each token is present
+        Returns:
+            Iterable[Sequence[Dict[str, Union[str, int]]]]: Iterable that iterates through all the sentences
+            and yields the list of tokens in each sentence
+        """
+        # Initialize the object that will be used to align tokens and spans based on the notation
+        # as mentioned earlier - this will be used only when mode is train - because we have
+        # access to labelled spans for the notes
+        if spans is None:
+            label_spans = NERPredictTokenLabels('O')
+        else:
+            label_spans = NERTokenLabels(spans=spans, notation=notation)
+        # Iterate through the sentences in the note
+        for sentence in self._sentencizer.get_sentences(text=text):
+            # This is used to determine the position of the tokens with respect to the entire note
+            offset = sentence['start']
+            # Keeps track of the tokens in the sentence
+            tokens = list()
+            for token in self._tokenizer.get_tokens(text=sentence['text']):
+                # Get the token position (start, end) in the note
+                token['start'] += offset
+                token['end'] += offset
+                if token[token_text_key].strip() in ['\n', '\t', ' ', ''] or token['start'] == token['end']:
+                    continue
+                # Shorten consecutive sequences of special characters, this can prevent BERT from truncating
+                # extremely long sentences - that could arise because of these characters
+                elif re.search('(\W|_){9,}', token[token_text_key]):
+                    print('WARNING - Shortening a long sequence of special characters from {} to 8'.format(
+                        len(token[token_text_key])))
+                    token[token_text_key] = re.sub('(?P<specialchar>(\W|_)){8,}', '\g<specialchar>' * 8,
+                                                   token[token_text_key])
+                elif len(token[token_text_key].split(' ')) != 1:
+                    print('WARNING - Token contains a space character - will be replaced with hyphen')
+                    token[token_text_key] = token[token_text_key].replace(' ', '-')
+                # Get the labels for each token based on the notation (BIO)
+                # In predict mode - the default label (e.g O) will be assigned
+                try:
+                    # Get the label for the token - based on the notation
+                    label = label_spans.get_labels(token=token)
+                    if label[2:] ==
'OTHERISSUE': + raise ValueError('Fix OTHERISSUE spans') + # Check if there is a token and span mismatch, i.e the token and span does not align + except MismatchError: + print(token) + raise ValueError('Token-Span mismatch') + token[label_key] = label + tokens.append(token) + if tokens: + yield tokens diff --git a/ner_datasets/dataset_builder/labels/__init__.py b/ner_datasets/dataset_builder/labels/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9b1e83a935e1afe8749eeb745772415dfe6cde2c --- /dev/null +++ b/ner_datasets/dataset_builder/labels/__init__.py @@ -0,0 +1,4 @@ +from .mismatch_error import MismatchError +from .ner_token_labels import NERTokenLabels +from .ner_predict_token_labels import NERPredictTokenLabels +__all__=["NERTokenLabels", "NERPredictTokenLabels", "MismatchError"] \ No newline at end of file diff --git a/ner_datasets/dataset_builder/labels/mismatch_error.py b/ner_datasets/dataset_builder/labels/mismatch_error.py new file mode 100644 index 0000000000000000000000000000000000000000..287a02baa5661989c37d6088ff1071b990a08eb9 --- /dev/null +++ b/ner_datasets/dataset_builder/labels/mismatch_error.py @@ -0,0 +1,7 @@ +# Exception thrown when there is a mismatch between a token and span +# The token and spans don't line up due to a tokenization issue +# E.g - 79M - span is AGE - 79, but token is 79M +# There is a mismatch and an error will be thrown - that is the token does +# not line up with the span +class MismatchError(Exception): + pass diff --git a/ner_datasets/dataset_builder/labels/ner_predict_token_labels.py b/ner_datasets/dataset_builder/labels/ner_predict_token_labels.py new file mode 100644 index 0000000000000000000000000000000000000000..5b203989f4243c1f7d1367952ba6f49c63515cf5 --- /dev/null +++ b/ner_datasets/dataset_builder/labels/ner_predict_token_labels.py @@ -0,0 +1,30 @@ +from typing import Mapping, Union, NoReturn + + +class NERPredictTokenLabels(object): + """ + Assign a default label while creating the dataset for prediction. + This is done since the sequence tagging code expects the input + file to contain a labels field, hence we assign a default label + to meet this requirement + """ + + def __init__(self, default_label: str) -> NoReturn: + """ + Initialize the default label + Args: + default_label (str): Default label that will be used + """ + # Keeps track of all the spans (list) in the text (note) + self._default_label = default_label + + def get_labels(self, token: Mapping[str, Union[str, int]]) -> str: + """ + Given a token, return the default label. + Args: + token (Mapping[str, Union[str, int]]): Contains the token text, start and end position of the token + in the text + Returns: + default_label (str): default label + """ + return self._default_label diff --git a/ner_datasets/dataset_builder/labels/ner_token_labels.py b/ner_datasets/dataset_builder/labels/ner_token_labels.py new file mode 100644 index 0000000000000000000000000000000000000000..f5b18b2f7512c903bf53e7fbdc422d682c377058 --- /dev/null +++ b/ner_datasets/dataset_builder/labels/ner_token_labels.py @@ -0,0 +1,156 @@ +from typing import Mapping, Union, Sequence, List +from .mismatch_error import MismatchError + + +class NERTokenLabels(object): + """ + This class is used to align tokens with the spans + Each token is assigned one of the following labels + 'B-LABEL', 'I-LABEL', 'O'. 
For example the text
+    '360 Longwood Avenue' is 3 tokens - [360, Longwood, Avenue]
+    and each token would be assigned the following labels
+    [B-LOC, I-LOC, I-LOC] (this would also depend on what
+    notation we are using). Generally the data after prodigy
+    annotation has all the tokens and all the spans.
+    We would have tokens: [tok1, tok2, ... tok n]
+    and spans: [span1: [tok1, tok2, tok3], span2: [tok7], ... span k]
+    This would be used to convert into the format we are using,
+    which is to assign a label to each token based on which span it
+    belongs to.
+    """
+
+    def __init__(
+            self,
+            spans: List[Mapping[str, Union[str, int]]],
+            notation: str
+    ):
+        """
+        Initialize variables that will be used to align tokens
+        and span labels. The spans variable will contain all the spans
+        in the note. Notation is whether we would like to use BIO, IO, BILOU,
+        when assigning the label to each token based on which span it belongs to.
+        Keep track of the total number of spans etc.
+        Args:
+            spans (List[Mapping[str, Union[str, int]]]): List of all the spans in the text
+            notation (str): NER label notation
+        """
+        # Keeps track of all the spans (list) in the text (note)
+        self._spans = spans
+        for span in self._spans:
+            if type(span['start']) != int or type(span['end']) != int:
+                raise ValueError('The start and end keys of the span must be of type int')
+        self._spans.sort(key=lambda _span: (_span['start'], _span['end']))
+        # The current span is the first element of the list
+        self._current_span = 0
+        # Boolean variable that indicates whether the token is inside
+        # the span (I-LABEL)
+        self._inside = False
+        # Total number of spans
+        self._span_count = len(self._spans)
+        # Depending on the notation passed, we will return the label for
+        # the token accordingly
+        if notation == 'BIO':
+            self._prefix_single = 'B-'
+            self._prefix_begin = 'B-'
+            self._prefix_inside = 'I-'
+            self._prefix_end = 'I-'
+            self._prefix_outside = 'O'
+        elif notation == 'BIOES':
+            self._prefix_single = 'S-'
+            self._prefix_begin = 'B-'
+            self._prefix_inside = 'I-'
+            self._prefix_end = 'E-'
+            self._prefix_outside = 'O'
+        elif notation == 'BILOU':
+            self._prefix_single = 'U-'
+            self._prefix_begin = 'B-'
+            self._prefix_inside = 'I-'
+            self._prefix_end = 'L-'
+            self._prefix_outside = 'O'
+        elif notation == 'IO':
+            self._prefix_single = 'I-'
+            self._prefix_begin = 'I-'
+            self._prefix_inside = 'I-'
+            self._prefix_end = 'I-'
+            self._prefix_outside = 'O'
+        else:
+            raise ValueError('Invalid Notation')
+
+    def __check_begin(self, token: Mapping[str, Union[str, int]]) -> str:
+        """
+        Given a token, return the label (B-LABEL) and check whether the token
+        covers the entire span or is a subset of the span
+        Args:
+            token (Mapping[str, Union[str, int]]): Contains the token text, start and end position of the token
+                in the text
+        Returns:
+            (str): The label - 'B-LABEL'
+        """
+        # Set the inside flag to true to indicate that the next token that is checked
+        # will be checked to see if it belongs 'inside' the span
+        self._inside = True
+        if token['end'] > int(self._spans[self._current_span]['end']):
+            raise MismatchError('Span and Token mismatch - Begin Token extends longer than the span')
+        # If this token does not cover the entire span then we expect another token
+        # to be in the span and that token should be assigned the I-LABEL
+        elif token['end'] < int(self._spans[self._current_span]['end']):
+            return self._prefix_begin + self._spans[self._current_span]['label']
+        # If this token does cover the entire span then we set inside = False
+        # to indicate this span is complete and increment the
current span + # to move onto the next span in the text + elif token['end'] == int(self._spans[self._current_span]['end']): + self._current_span += 1 + self._inside = False + return self._prefix_single + self._spans[self._current_span - 1]['label'] + + def __check_inside(self, token: Mapping[str, Union[str, int]]) -> str: + """ + Given a token, return the label (I-LABEL) and check whether the token + covers the entire span or is still inside the span. + Args: + token (Mapping[str, Union[str, int]]): Contains the token text, start and end position of the token + in the text + Returns: + (str): The label - 'I-LABEL' + """ + + if (token['start'] >= int(self._spans[self._current_span]['end']) + or token['end'] > int(self._spans[self._current_span]['end'])): + raise MismatchError('Span and Token mismatch - Inside Token starts after the span ends') + # If this token does not cover the entire span then we expect another token + # to be in the span and that token should be assigned the I-LABEL + elif token['end'] < int(self._spans[self._current_span]['end']): + return self._prefix_inside + self._spans[self._current_span]['label'] + # If this token does cover the entire span then we set inside = False + # to indicate this span is complete and increment the current span + # to move onto the next span in the text + elif token['end'] == int(self._spans[self._current_span]['end']): + self._current_span += 1 + self._inside = False + return self._prefix_end + self._spans[self._current_span - 1]['label'] + + def get_labels(self, token: Mapping[str, Union[str, int]]) -> str: + """ + Given a token, return the label (B-LABEL, I-LABEL, O) based on + the spans present in the text & the desired notation. + Args: + token (Mapping[str, Union[str, int]]): Contains the token text, start and end position of the token + in the text + Returns: + (str): One of the labels according to the notation - 'B-LABEL', 'I-LABEL', 'O' + """ + # If we have iterated through all the spans in the text (note), all the tokens that + # come after the last span will be marked as 'O' - since they don't belong to any span + if self._current_span >= self._span_count: + return self._prefix_outside + # Check if the span can be assigned the B-LABEL + if token['start'] == int(self._spans[self._current_span]['start']): + return self.__check_begin(token) + # Check if the span can be assigned the I-LABEL + elif token['start'] > int(self._spans[self._current_span]['start']) and self._inside is True: + return self.__check_inside(token) + # Check if the token is outside a span + elif self._inside is False and (token['end'] <= int(self._spans[self._current_span]['start'])): + return self._prefix_outside + else: + raise MismatchError( + 'Span and Token mismatch - the span and tokens don\'t line up. There might be a tokenization issue ' + 'that needs to be fixed') diff --git a/ner_datasets/dataset_builder/sentence_dataset.py b/ner_datasets/dataset_builder/sentence_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..2448189a5986bfa8e61f157497513f74f6137b98 --- /dev/null +++ b/ner_datasets/dataset_builder/sentence_dataset.py @@ -0,0 +1,355 @@ +from collections import deque +from typing import Deque, List, Sequence, Iterable, Optional, NoReturn, Dict, Mapping, Union, Tuple + + +class SentenceDataset(object): + """ + When we mention previous sentence and next sentence, we don't mean exactly one sentence + but rather a previous chunk and a next chunk. 
This can include one or more sentences and
+    it does not mean that the sentence has to be complete (it can be cut off in between) - hence a chunk.
+    This class is used to build a dataset at the sentence
+    level. It takes as input all the tokenized sentences in the note. So the input is
+    a list of lists where the outer list represents the sentences in the note and the inner list
+    is a list of tokens in the sentence. It then returns a dataset where each sentence is
+    concatenated with the previous and a next chunk. This is done so that when we build a model
+    we can use the previous and next chunks to add context to the sentence/model. The weights and loss etc
+    will be computed and updated based on the current sentence. The previous and next chunks will
+    only be used to add context. We could have different sizes of previous and next chunks
+    depending on the position of the sentence etc. Essentially we build a sentence level dataset
+    where we can also provide context to the sentence by including the previous and next chunks
+    """
+
+    def __init__(
+            self,
+            max_tokens: int,
+            max_prev_sentence_token: int,
+            max_next_sentence_token: int,
+            default_chunk_size: int,
+            ignore_label: str
+    ) -> NoReturn:
+        """
+        Set the maximum token length a given training example (sentence level) can have.
+        That is the total length of the current sentence + previous chunk + next chunk
+        We also set the maximum length of the previous and next chunks. That is how many
+        tokens can be in these chunks. However if the total length exceeds max_tokens, tokens in the
+        previous and next chunks will be dropped to ensure that the total length is < max_tokens
+        The default chunk size ensures that the length of the chunks will be a minimum number of
+        tokens based on the value passed. For example if default_chunk_size=10, the length
+        of the previous chunks and next chunks will be at least 10 tokens.
+        Args:
+            max_tokens (int): maximum token length a given training example (sentence level) can have
+            max_prev_sentence_token (int): The max chunk size for the previous chunks for a given sentence
+                (training/prediction example) in the note can have
+            max_next_sentence_token (int): The max chunk size for the next chunks for a given sentence
+                (training/prediction example) in the note can have
+            default_chunk_size (int): the training example will always include a chunk of this length
+                as part of the previous and next chunks
+            ignore_label (str): The label assigned to the previous and next chunks to distinguish
+                them from the current sentence
+        """
+        self._id_num = None
+        self._max_tokens = max_tokens
+        self._max_prev_sentence_token = max_prev_sentence_token
+        self._max_next_sentence_token = max_next_sentence_token
+        self._default_chunk_size = default_chunk_size
+        self._ignore_label = ignore_label
+
+    @staticmethod
+    def chunker(
+            seq: Sequence[Mapping[str, Union[str, int]]],
+            size: int
+    ) -> Iterable[Sequence[Mapping[str, Union[str, int]]]]:
+        """
+        Return chunks of the sequence. The size of each chunk will be based
+        on the value passed to the size argument.
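+        Example (illustrative, with plain ints standing in for token dicts):
+            >>> list(SentenceDataset.chunker([1, 2, 3, 4, 5], 2))
+            [[1, 2], [3, 4], [5]]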
+        Args:
+            seq (Sequence[Mapping[str, Union[str, int]]]): The sequence of tokens to be chunked
+            size (int): The max chunk size for the chunks
+        Return:
+            (Iterable[Sequence[Mapping[str, Union[str, int]]]]): Iterable that iterates through fixed size
+            chunks of the input sequence
+        """
+        return (seq[pos:pos + size] for pos in range(0, len(seq), size))
+
+    def get_previous_sentences(self, sent_tokens: Sequence[Sequence[Mapping[str, Union[str, int]]]]) -> List[Deque]:
+        """
+        Go through all the sentences in the medical note and create a list of
+        previous sentences. The output of this function will be a list of chunks
+        where each index of the list contains the sentences (chunks) - (tokens) present before
+        the sentence at that index in the medical note. For example prev_sent[0] will
+        be empty since there is no sentence before the first sentence in the note and
+        prev_sent[1] will be equal to sent[0], that is the previous sentence of the
+        second sentence will be the first sentence. We make use of a deque, where we
+        start to dequeue elements when the length starts to exceed max_prev_sentence_token. This
+        list of previous sentences will be used to define the previous chunks
+        Args:
+            sent_tokens (Sequence[Sequence[Mapping[str, Union[str, int]]]]): Sentences in the note, where
+                each element of the list contains a list of tokens in that sentence
+        Returns:
+            previous_sentences (List[deque]): A list of deque objects where each index contains a
+                list (queue) of previous tokens (chunk) with respect
+                to the sentence represented by that index in the note
+        """
+        previous_sentences = list()
+        # Create a queue and specify the capacity of the queue
+        # Tokens will be popped from the queue when the capacity is exceeded
+        prev_sentence = deque(maxlen=self._max_prev_sentence_token)
+        # The first previous chunk is empty since the first sentence in the note does not have
+        # anything before it
+        previous_sentences.append(prev_sentence.copy())
+        # As we iterate through the list of sentences in the note, we add the tokens from the previous chunks
+        # to the queue. Since we have a queue, as soon as the capacity is exceeded we pop tokens from
+        # the queue
+        for sent_token in sent_tokens[:-1]:
+            for token in sent_token:
+                prev_sentence.append(token)
+            # As soon as each sentence in the list is processed
+            # we add a copy of the current queue to a list - this list keeps track of the
+            # previous chunks for a sentence
+            previous_sentences.append(prev_sentence.copy())
+
+        return previous_sentences
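+    # For illustration (hypothetical tokens): with sent_tokens = [[t1, t2], [t3], [t4]]
+    # and max_prev_sentence_token=2, the previous chunks come out as:
+    #   previous_sentences[0] == deque([])         - nothing precedes sentence 0
+    #   previous_sentences[1] == deque([t1, t2])
+    #   previous_sentences[2] == deque([t2, t3])   - t1 dropped once the maxlen of 2 is hit
+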
+    def get_next_sentences(self, sent_tokens: Sequence[Sequence[Mapping[str, Union[str, int]]]]) -> List[Deque]:
+        """
+        Go through all the sentences in the medical note and create a list of
+        next sentences. The output of this function will be a list of lists
+        where each index of the list contains the list of sentences present after
+        the sentence at that index in the medical note. For example next_sent[-1] will
+        be empty since there is no sentence after the last sentence in the note and
+        next_sent[0] will be equal to sent[1:], that is the next sentence of the
+        first sentence will be the subsequent sentences. We make use of a deque, where we
+        start to dequeue elements when the length starts to exceed max_next_sentence_token. This
+        list of next sentences will be used to define the next chunks
+        Args:
+            sent_tokens (Sequence[Sequence[Mapping[str, Union[str, int]]]]): Sentences in the note, where each
+                element of the list contains a list of tokens in that sentence
+        Returns:
+            next_sentences (List[deque]): A list of deque objects where each index contains a list (queue)
+                of next tokens (chunk) with respect to the sentence represented
+                by that index in the note
+        """
+        # A list of next sentences is first created and reversed
+        next_sentences = list()
+        # Create a queue and specify the capacity of the queue
+        # Tokens will be popped from the queue when the capacity is exceeded
+        next_sentence = deque(maxlen=self._max_next_sentence_token)
+        # The first (which becomes the last chunk when we reverse this list) next chunk is empty since
+        # the last sentence in the note does not have
+        # anything after it
+        next_sentences.append(next_sentence.copy())
+        for sent_token in reversed(sent_tokens[1:]):
+            for token in reversed(sent_token):
+                next_sentence.appendleft(token)
+            next_sentences.append(next_sentence.copy())
+        # The list is reversed - since we went through the sentences in the reverse order in
+        # the earlier steps
+        return [next_sent for next_sent in reversed(next_sentences)]
+
+    def get_sentences(
+            self,
+            sent_tokens: Sequence[Sequence[Mapping[str, Union[str, int]]]],
+            token_text_key: str = 'text',
+            label_key: str = 'label',
+            start_chunk: Optional[Sequence[Mapping[str, Union[str, int]]]] = None,
+            end_chunk: Optional[Sequence[Mapping[str, Union[str, int]]]] = None,
+            sub: bool = False
+    ) -> Iterable[Tuple[int, Dict[str, Union[List[Dict[str, Union[str, int]]], List[str]]]]]:
+        """
+        When we mention previous sentence and next sentence, we don't mean exactly one sentence
+        but rather a previous chunk and a next chunk. This can include one or more sentences and
+        it does not mean that the sentence has to be complete (it can be cut off in between) - hence a chunk.
+        We iterate through all the tokenized sentences in the note. So the input is
+        a list of lists where the outer list represents the sentences in the note and the inner list
+        is a list of tokens in the sentence. It then returns a dataset where each sentence is
+        concatenated with the previous and the next sentence. This is done so that when we build a model
+        we can use the previous and next sentence to add context to the model. The weights and loss etc
+        will be computed and updated based on the current sentence. The previous and next sentence will
+        only be used to add context. We could have different sizes of previous and next chunks
+        depending on the position of the sentence etc., since we split a note into several sentences
+        which are then used as training data.
+        ignore_label is used to differentiate between the current sentence and the previous and next
+        chunks. The chunks will have the label NA and the current sentence
+        will have the labels (DATE, AGE etc) so that they can be distinguished.
+        If however we are building a dataset for predictions,
+        the current sentence will have the default label O, but the next and previous chunks will still
+        have the label NA. However if the total length exceeds max_tokens, tokens in the
+        previous and next chunks will be dropped to ensure that the total length is < max_tokens
+        The default chunk size ensures that the length of the chunks will be a minimum number of
+        tokens based on the value passed. For example if default_chunk_size=10, the length
+        of the previous chunks and next chunks will be at least 10 tokens.
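+        As a worked illustration (the numbers are made up): with max_tokens=512,
+        default_chunk_size=10, a 500-token current sentence and 32-token previous and
+        next chunks, tokens are popped from the longer chunk until both chunks reach
+        the 10-token floor; 10 + 500 + 10 = 520 still exceeds 512, so the sentence is
+        handled by the splitting step described next.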
If the total length > max tokens + even after decreasing the sizes of the previous and next chunks, then we split this long + sentence into sub sentences and repeat the process described above. + Args: + sent_tokens (Sequence[Sequence[Mapping[str, Union[str, int]]]]): Sentences in the note and each sentence + contains the tokens (dict) in that sentence + the token dict object contains the + token text, start, end etc + token_text_key (str): Each sentence contains a list of tokens where each token is a dict. We use the text + key to extract the text of the token from the dictionary + label_key (str): Each sentence contains a list of tokens where each token is a dict. We use the label_key + key to extract the label of the token from the dictionary. (if it does not have a label + the default label will be assigned) + start_chunk (Optional[Sequence[Mapping[str, Union[str, int]]]]): Prefix the first sentence of with some + pre-defined chunk + end_chunk (Optional[Sequence[Mapping[str, Union[str, int]]]]): Suffix the last sentence of with some + pre-defined chunk + sub (bool): Whether the function is called to process sub-sentences (used when we are splitting + long sentences into smaller sub sentences to keep sentence length < max_tokens + Returns: + (Iterable[Tuple[int, Dict[str, Union[List[Dict[str, Union[str, int]]], List[str]]]]]): Iterate through the + returned sentences, + where each sentence + has the previous + chunks and next + chunks attached + to it. + """ + # Id num keeps track of the id of the sentence - that is the position the sentence occurs in + # the note. We keep the id of sub sentences the same as the sentence, so that the user + # knows that these sub sentences are chunked from a longer sentence. + # . Say length of sent 0 with the previous and next chunks is less than max_tokens + # we return sent 0 with id 0. For sent 1, say the length is longer, we split it into sub + # sentences - - we return SUB 1, and SUB 2 with id 1 - so we know that it belongs + # to in the note. + if not sub: + self._id_num = -1 + # Initialize the object that will take all the sentences in the note and return + # a dataset where each row represents a sentence in the note. The sentence in each + # row will also contain a previous chunk and next chunk (tokens) that will act as context + # when training the mode + # [ps1, ps 2, ps 3...ps-i], [cs1, cs2, ... cs-j], [ns, ns, ... ns-k] - as you can see the current sentence + # which is the sentence we train on (or predict on) will be in the middle - the surrounding tokens will + # provide context to the current sentence + # Get the previous sentences (chunks) for each sentence in the note + previous_sentences = self.get_previous_sentences(sent_tokens) + # Get the next sentences (chunks) for each sentence in the note + next_sentences = self.get_next_sentences(sent_tokens) + # For the note we are going to iterate through all the sentences in the note and + # concatenate each sentence with the previous and next chunks. (This forms the data that + # will be used for training/predictions) Each sentence with the concatenated chunks will be + # a training sample. We would do the same thing for getting predictions on a sentence as well + # The only difference would be the labels that are used. 
We would use the default label O for
+        # prediction and the annotated labels for training
+        if len(sent_tokens) != len(previous_sentences) or len(sent_tokens) != len(next_sentences):
+            raise ValueError('Sentence length mismatch')
+        for index, (previous_sent, current_sent, next_sent) in enumerate(
+                zip(previous_sentences, sent_tokens, next_sentences)):
+            sent_tokens_text = list()
+            sent_labels = list()
+            sent_toks = list()
+            # Get the tokens and labels for the current sentence
+            for token in current_sent:
+                # We store this, if we need to process sub sentences when a sentence exceeds max_tokens
+                sent_toks.append(token)
+                sent_tokens_text.append(token[token_text_key])
+                sent_labels.append(token[label_key])
+            # We check if the number of tokens in the current sentence + previous chunk
+            # + next chunk exceeds max tokens. If it does we start popping tokens from the previous and next chunks
+            # until the number of tokens is equal to max tokens
+            previous_sent_length = len(previous_sent)
+            current_sent_length = len(sent_tokens_text)
+            next_sent_length = len(next_sent)
+            total_length = previous_sent_length + current_sent_length + next_sent_length
+            # If the length of the current sentence plus the length of the previous and next
+            # chunks exceeds the max_tokens, start popping tokens from the previous and next
+            # chunks until either total length < max_tokens or the number of tokens in the previous and
+            # next chunks goes below the default chunk size
+            while total_length > self._max_tokens and \
+                    (next_sent_length > self._default_chunk_size or previous_sent_length > self._default_chunk_size):
+                if next_sent_length >= previous_sent_length:
+                    next_sent.pop()
+                    next_sent_length -= 1
+                    total_length -= 1
+                elif previous_sent_length > next_sent_length:
+                    previous_sent.popleft()
+                    previous_sent_length -= 1
+                    total_length -= 1
+            # If this is not a sub sentence, increment the ID to
+            # indicate the processing of the next sentence of the note
+            # If it is a sub sentence, keep the ID the same, to indicate
+            # it belongs to a larger sentence
+            if not sub:
+                self._id_num += 1
+            # If total length < max_tokens - process the sentence with the current sentence
+            # and add on the previous and next chunks and return
+            if total_length <= self._max_tokens:
+                # Check if we want to add a pre-defined chunk for the first sentence in the note
+                if index == 0 and start_chunk is not None:
+                    previous_sent_tokens = [chunk[token_text_key] for chunk in start_chunk] + \
+                                           [prev_token[token_text_key] for prev_token in list(previous_sent)]
+                else:
+                    previous_sent_tokens = [prev_token[token_text_key] for prev_token in list(previous_sent)]
+                # Check if we want to add a pre-defined chunk for the last sentence in the note
+                if index == len(sent_tokens) - 1 and end_chunk is not None:
+                    next_sent_tokens = [next_token[token_text_key] for next_token in list(next_sent)] + \
+                                       [chunk[token_text_key] for chunk in end_chunk]
+                else:
+                    next_sent_tokens = [next_token[token_text_key] for next_token in list(next_sent)]
+                previous_sent_length = len(previous_sent_tokens)
+                next_sent_length = len(next_sent_tokens)
+                # Store information about the current sentence - start and end pos etc
+                # this can be used to distinguish from the next and previous chunks
+                # current_sent_info = {'token_info':current_sent}
+                # Assign a different label (the ignore label) to the chunks - since they are used only for context
+                previous_sent_labels = list()
+                next_sent_labels = list()
+                if self._ignore_label == 'NA':
+                    previous_sent_labels = [self._ignore_label] * previous_sent_length
+                    next_sent_labels = [self._ignore_label] * next_sent_length
+                elif self._ignore_label == 'label':
+                    if index == 0 and start_chunk is not None:
+                        previous_sent_labels = [chunk[label_key] for chunk in start_chunk] + \
+                                               [prev_token[label_key] for prev_token in list(previous_sent)]
+                    else:
+                        previous_sent_labels = [prev_token[label_key] for prev_token in list(previous_sent)]
+                    if index == len(sent_tokens) - 1 and end_chunk is not None:
+                        next_sent_labels = [next_token[label_key] for next_token in list(next_sent)] + \
+                                           [chunk[label_key] for chunk in end_chunk]
+                    else:
+                        next_sent_labels = [next_token[label_key] for next_token in list(next_sent)]
+                # Concatenate the chunks and the sentence
+                tokens_data = previous_sent_tokens + sent_tokens_text + next_sent_tokens
+                labels_data = previous_sent_labels + sent_labels + next_sent_labels
+                # Return processed sentences
+                yield self._id_num, {'tokens': tokens_data, 'labels': labels_data, 'current_sent_info': current_sent}
+            # Process the sub sentences - we take a long sentence
+            # and split it into smaller chunks - and we recursively call the function on this list
+            # of smaller chunks - as mentioned before the smaller chunks (sub sentences) will have the
+            # same ID as the original sentence
+            else:
+                # The current sentence is too long, so we split it into smaller chunks (sub
+                # sentences) and pass this list of chunks to this function as a recursive call.
+                # The list is then processed as a smaller note that essentially belongs to a
+                # single sentence. The previous and next chunks of the original sentence are not
+                # part of the current sentence, but they still need to appear in the final output,
+                # so we pass previous_sent_tokens and next_sent_tokens as the start_chunk and
+                # end_chunk of the recursive call - the first sub sentence is prefixed with the
+                # previous chunk and the last sub sentence is suffixed with the next chunk.
+                # All sub sentences keep the ID of the original sentence
+                sub_sentences = list()
+                # Prefix the first sentence in these smaller chunks
+                previous_sent_tokens = list(previous_sent)
+                # Suffix the last sentence in these smaller chunks
+                next_sent_tokens = list(next_sent)
+                # Get chunks
+                for chunk in SentenceDataset.chunker(sent_toks, self._max_tokens - (2 * self._default_chunk_size)):
+                    sub_sentences.append(chunk)
+                # Process list of smaller chunks
+                for sub_sent in self.get_sentences(
+                        sub_sentences,
+                        token_text_key,
+                        label_key,
+                        start_chunk=previous_sent_tokens,
+                        end_chunk=next_sent_tokens,
+                        sub=True
+                ):
+                    yield sub_sent
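+
+# A minimal usage sketch (the values below are assumptions for illustration, not a
+# documented API):
+#   dataset = SentenceDataset(max_tokens=128, max_prev_sentence_token=32,
+#                             max_next_sentence_token=32, default_chunk_size=8,
+#                             ignore_label='NA')
+#   sents = [[{'text': 'Bruce', 'label': 'B-PATIENT'}, {'text': 'Wayne', 'label': 'I-PATIENT'}],
+#            [{'text': 'Gotham', 'label': 'B-LOC'}]]
+#   for sent_id, example in dataset.get_sentences(sents):
+#       print(sent_id, example['tokens'], example['labels'])
+#   # -> 0 ['Bruce', 'Wayne', 'Gotham'] ['B-PATIENT', 'I-PATIENT', 'NA']
+#   # -> 1 ['Bruce', 'Wayne', 'Gotham'] ['NA', 'NA', 'B-LOC']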
diff --git a/ner_datasets/dataset_creator.py b/ner_datasets/dataset_creator.py
new file mode 100644
index 0000000000000000000000000000000000000000..3b6b4a1c714bcaff88587aea1eb6ca7326b5dc55
--- /dev/null
+++ b/ner_datasets/dataset_creator.py
@@ -0,0 +1,322 @@
+import json
+import random
+from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter
+from typing import Iterable, Dict, List, Union, Optional, Sequence, NoReturn
+
+from .dataset_builder import Dataset, SentenceDataset
+from .preprocessing import PreprocessingLoader
+
+random.seed(41)
+
+
+class DatasetCreator(object):
+    """
+    Build a NER token classification dataset
+    For training we will build the dataset using the annotated spans (e.g from prodigy)
+    For predictions we will assign default labels.
+    The dataset is on a sentence level, i.e each note is split into sentences and the de-id
+    task is run on a sentence level.
Even the predictions are run on a sentence level
+    The dataset would be something like:
+    Tokens: [[tok1, tok2, ... tok-n], [tok ...], ..., [tok ...]]
+    Labels: [[lab1, lab2, ... lab-n], [lab ...], ..., [lab ...]]
+    Where the inner list represents the sentences - the tokens in the sentence and the respective
+    labels for each token. The labels depend on the notation
+    This script can also be used for predictions, the Labels will be filled with some
+    default value. This is done so that we can use the same script for building a dataset to train a model
+    and a dataset to obtain predictions using a model
+    Example:
+    Note: Bruce Wayne Jr is a 60yo man. He lives in Gotham
+    Sentences: [Bruce Wayne Jr is a 60yo man., He lives in Gotham]
+    Tokens: [[Bruce, Wayne, Jr, is, a, 60, yo, man, .], [He, lives, in, Gotham]]
+    Labels (BIO notation): [[B-Name, I-Name, I-Name, O, O, O, O, O, O], [O, O, O, B-LOC]]
+    Labels (BILOU notation): [[B-Name, I-Name, L-Name, O, O, O, O, O, O], [O, O, O, U-LOC]]
+    We can also create sentences that use previous/next chunks as context - in this case the dataset would
+    look something like this. (Assume we limit the size of the chunks to 3 tokens)
+    Sentences: [Bruce Wayne Jr is a 60yo man., He lives in Gotham]
+    Tokens: [[Bruce, Wayne, Jr, is, a, 60, yo, man, ., He, lives, in], [yo, man, ., He, lives, in, Gotham]]
+    Labels (BIO notation): [[B-Name, I-Name, I-Name, O, O, O, O, O, O, NA, NA, NA], [NA, NA, NA, O, O, O, B-LOC]]
+    Labels (BILOU notation): [[B-Name, I-Name, L-Name, O, O, O, O, O, O, NA, NA, NA], [NA, NA, NA, O, O, O, U-LOC]]
+    NA indicates that the token is used only for context
+    """
+
+    def __init__(
+            self,
+            sentencizer: str,
+            tokenizer: str,
+            abbreviations: Optional[Sequence[str]] = None,
+            max_tokens: int = 128,
+            max_prev_sentence_token: int = 32,
+            max_next_sentence_token: int = 32,
+            default_chunk_size: int = 32,
+            ignore_label: str = 'NA'
+    ) -> NoReturn:
+        """
+        Initialize the sentencizer and tokenizer
+        Args:
+            sentencizer (str): Specify which sentencizer you want to use
+            tokenizer (str): Specify which tokenizer you want to use
+            abbreviations (Optional[Sequence[str]]): A list of abbreviations for which tokens will not be split
+                - works only with the custom clinical tokenizer.
+            max_tokens (int): The maximum number of tokens allowed in a sentence/training example,
+                truncate if it exceeds.
+            max_prev_sentence_token (int): The maximum number of previous chunk tokens allowed in a
+                sentence/training example
+            max_next_sentence_token (int): The maximum number of next chunk tokens allowed in a
+                sentence/training example.
+            default_chunk_size (int): The training example will always include a chunk of at least this
+                length as part of the previous and next chunks
+            ignore_label (str): The label assigned to the previous and next chunks to distinguish
+                them from the current sentence
+        """
+        self._sentencizer = PreprocessingLoader.get_sentencizer(sentencizer=sentencizer)
+        self._tokenizer = PreprocessingLoader.get_tokenizer(tokenizer=tokenizer, abbreviations=abbreviations)
+        # Initialize the object that will be used to get the tokens and the sentences
+        self._dataset = Dataset(sentencizer=self._sentencizer, tokenizer=self._tokenizer)
+        # Initialize the object that will take all the sentences in the note and return
+        # a dataset where each row represents a sentence in the note. The sentence in each
+        # row will also contain a previous chunk and next chunk (tokens) that will act as context
+        # when training the model
+        # [ps1, ps 2, ps 3...ps-i], [cs1, cs2, ... cs-j], [ns, ns, ...
ns-k] - as you can see the current sentence + # which is the sentence we train on (or predict on) will be in the middle - the surrounding tokens will + # provide context to the current sentence + self._sentence_dataset = SentenceDataset( + max_tokens=max_tokens, + max_prev_sentence_token=max_prev_sentence_token, + max_next_sentence_token=max_next_sentence_token, + default_chunk_size=default_chunk_size, + ignore_label=ignore_label + ) + + def create( + self, + input_file: str, + mode: str = 'predict', + notation: str = 'BIO', + token_text_key: str = 'text', + metadata_key: str = 'meta', + note_id_key: str = 'note_id', + label_key: str = 'labels', + span_text_key: str = 'spans' + ) -> Iterable[Dict[str, Union[List[Dict[str, Union[str, int]]], List[str]]]]: + """ + This function is used to get the sentences that will be part of the NER dataset. + We check whether the note belongs to the desired dataset split. If it does, + we fix any spans that can cause token-span alignment errors. Then we extract + all the sentences in the notes, the tokens in each sentence. Finally we + add some context tokens to the sentence if required. This function returns + an iterable that iterated through each of the processed sentences + Args: + input_file (str): Input jsonl file. Make sure the spans are in ascending order (based on start position) + mode (str): Dataset being built for train or predict. + notation (str): The NER labelling notation + token_text_key (str): The key where the note text and token text is present in the json object + metadata_key (str): The key where the note metadata is present in the json object + note_id_key (str): The key where the note id is present in the json object + label_key (str): The key where the token label will be stored in the json object + span_text_key (str): The key where the note spans is present in the json object + Returns: + (Iterable[Dict[str, Union[List[Dict[str, Union[str, int]]], List[str]]]]): Iterate through the processed + sentences/training examples + """ + # Go through the notes + for line in open(input_file, 'r'): + note = json.loads(line) + note_text = note[token_text_key] + note_id = note[metadata_key][note_id_key] + if mode == 'train': + note_spans = note[span_text_key] + # No spans in predict mode + elif mode == 'predict': + note_spans = None + else: + raise ValueError("Invalid mode - can only be train/predict") + # Store the list of tokens in the sentence + # Eventually this list will contain all the tokens in the note (split on the sentence level) + # Store the start and end positions of the sentence in the note. This can + # be used later to reconstruct the note from the sentences + # we also store the note_id for each sentence so that we can map it back + # to the note and therefore have all the sentences mapped back to the notes they belong to. + sent_tokens = [sent_tok for sent_tok in self._dataset.get_tokens( + text=note_text, + spans=note_spans, + notation=notation + )] + # The following loop goes through each sentence in the note and returns + # the current sentence and previous and next chunks that will be used for context + # The chunks will have a default label (e.g NA) to distinguish from the current sentence + # and so that we can ignore these chunks when calculating loss and updating weights + # during training + for ner_sent_index, ner_sentence in self._sentence_dataset.get_sentences( + sent_tokens=sent_tokens, + token_text_key=token_text_key, + label_key=label_key + ): + # Return the processed sentence. 
This sentence will then be used + # by the model + current_sent_info = ner_sentence['current_sent_info'] + note_sent_info_store = {'start': current_sent_info[0]['start'], + 'end': current_sent_info[-1]['end'], 'note_id': note_id} + ner_sentence['note_sent_info'] = note_sent_info_store + yield ner_sentence + + +def main(): + cli_parser = ArgumentParser( + description='configuration arguments provided at run time from the CLI', + formatter_class=ArgumentDefaultsHelpFormatter + ) + cli_parser.add_argument( + '--input_file', + type=str, + required=True, + help='the the jsonl file that contains the notes. spans need to be sorted in ascending order (based on start ' + 'position) ' + ) + cli_parser.add_argument( + '--notation', + type=str, + default='BIO', + help='the notation we will be using for the label scheme' + ) + cli_parser.add_argument( + '--max_tokens', + type=int, + default=128, + help='The max tokens that a given sentence (training/prediction example) in the note can have' + ) + cli_parser.add_argument( + '--default_chunk_size', + type=int, + default=32, + help='the default chunk size for the previous and next chunks for a given sentence (training/prediction ' + 'example) in the note can have ' + ) + cli_parser.add_argument( + '--max_prev_sentence_token', + type=int, + default=32, + help='the max chunk size for the previous chunks for a given sentence (training/prediction example) in the ' + 'note can have ' + ) + cli_parser.add_argument( + '--max_next_sentence_token', + type=int, + default=32, + help='the max chunk size for the next chunks for a given sentence (training/prediction example) in the note ' + 'can have ' + ) + cli_parser.add_argument( + '--mode', + type=str, + choices=['train', 'predict'], + required=True, + help='whether we are building the dataset for training or prediction' + ) + cli_parser.add_argument( + '--sentencizer', + type=str, + required=True, + help='the sentencizer to use for splitting notes into sentences' + ) + cli_parser.add_argument( + '--tokenizer', + type=str, + required=True, + help='the tokenizer to use for splitting text into tokens' + ) + cli_parser.add_argument( + '--abbreviations', + type=str, + default=None, + help='file that will be used by clinical tokenizer to handle abbreviations' + ) + cli_parser.add_argument( + '--ignore_label', + type=str, + default='NA', + help='whether to use the ignore label or not' + ) + cli_parser.add_argument( + '--token_text_key', + type=str, + default='text', + help='the key where the note text is present in the json object' + ) + cli_parser.add_argument( + '--metadata_key', + type=str, + default='meta', + help='the key where the note metadata is present in the json object' + ) + cli_parser.add_argument( + '--note_id_key', + type=str, + default='note_id', + help='the key where the note metadata is present in the json object' + ) + cli_parser.add_argument( + '--label_key', + type=str, + default='label', + help='the key where the note label for each token is present in the json object' + ) + cli_parser.add_argument( + '--span_text_key', + type=str, + default='spans', + help='the key where the note annotates spans are present in the json object' + ) + cli_parser.add_argument( + '--format', + type=str, + default='jsonl', + help='format to store the dataset in: jsonl or conll' + ) + cli_parser.add_argument( + '--output_file', + type=str, + help='The file where the NER dataset will be stored' + ) + args = cli_parser.parse_args() + dataset_creator = DatasetCreator( + sentencizer=args.sentencizer, + 
tokenizer=args.tokenizer, + abbreviations=args.abbreviations, + max_tokens=args.max_tokens, + max_prev_sentence_token=args.max_prev_sentence_token, + max_next_sentence_token=args.max_next_sentence_token, + default_chunk_size=args.default_chunk_size, + ignore_label=args.ignore_label) + ner_notes = dataset_creator.create( + input_file=args.input_file, + mode=args.mode, + notation=args.notation, + token_text_key=args.token_text_key, + metadata_key=args.metadata_key, + note_id_key=args.note_id_key, + label_key=args.label_key, + span_text_key=args.span_text_key + ) + # Store the NER dataset in the desired format + if args.format == 'jsonl': + # Write the dataset to the output file + with open(args.output_file, 'w') as file: + for ner_sentence in ner_notes: + file.write(json.dumps(ner_sentence) + '\n') + elif args.format == 'conll': + with open(args.output_file, 'w') as file: + for ner_sentence in ner_notes: + tokens = ner_sentence['tokens'] + labels = ner_sentence['labels'] + current_sent_info = ner_sentence['current_sent_info'] + note_id = ner_sentence['note_sent_info']['note_id'] + if len(tokens) != len(labels) or len(labels) != len(current_sent_info): + raise ValueError('Length mismatch') + for token, label, sent_info in zip(tokens, labels, current_sent_info): + sent_info['note_id'] = note_id + data = token + ' ' + label + ' ' + json.dumps(sent_info) + '\n' + file.write(data) + file.write('\n') + + +if __name__ == '__main__': + + main() diff --git a/ner_datasets/dataset_splitter.py b/ner_datasets/dataset_splitter.py new file mode 100644 index 0000000000000000000000000000000000000000..285e22a25162c227feb586855f8c097aa0b13f44 --- /dev/null +++ b/ner_datasets/dataset_splitter.py @@ -0,0 +1,294 @@ +import json +import random +from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter +from collections import Counter +from typing import NoReturn, List + +from .distribution import NERDistribution, DatasetSplits, PrintDistribution + +random.seed(41) + + +class DatasetSplitter(object): + """ + Prepare dataset splits - training, validation & testing splits + Compute ner distributions in our dataset. Compute ner distributions + based on which we create and store a dictionary which will contain + information about which notes (in a dataset) belong to which split. + Based on this distribution and whether we want to keep certain notes + grouped (e.g by patient) we assign notes to a split, such that the + final ner type distribution in each split is similar. + """ + + def __init__( + self, + train_proportion: int = 70, + validation_proportion: int = 15, + test_proportion: int = 15 + ) -> NoReturn: + """ + Initialize the proportions of the splits. + Args: + train_proportion (int): Ratio of train dataset + validation_proportion (int): Ratio of validation dataset + test_proportion (int): Ratio of test dataset + """ + self._train_proportion = train_proportion + self._validation_proportion = validation_proportion + self._test_proportion = test_proportion + self._split = None + self._lookup_split = dict() + + def get_split(self, split: str) -> List[str]: + return [key for key in self._lookup_split[split].keys()] + + def set_split(self, split: str) -> NoReturn: + """ + Set the split that you are currently checking/processing. + Based on the split you can perform certain checks and + computation. Once the split is set, read the information + present in the split_info_path. Extract only the information + belonging to the split. 
+        The lookup hash map (built by assign_splits) has the
+        note_ids/patient ids that belong to each split as its keys. This
+        hash map can then be used to check whether a particular note
+        belongs to the currently set split.
+        Args:
+            split (str): The split - 'train', 'validation' or 'test'
+        """
+        if split not in ['train', 'validation', 'test']:
+            raise ValueError('Invalid split')
+        self._split = split
+
+    def __update_split(self, key: str) -> NoReturn:
+        """
+        Update the hash map where we have
+        the keys (e.g note_id) that belong to the split. This hash map
+        can then be used to check if a particular note belongs to this
+        split.
+        Args:
+            key (str): The key that identifies the note belonging to the split
+        """
+        self._lookup_split[self._split][key] = 1
+
+    def check_note(self, key: str) -> bool:
+        """
+        Use the lookup hash map built by assign_splits to check if the note
+        (identified by key) belongs to the currently set split (train,
+        validation or test). If it does, return True, else False.
+        Args:
+            key (str): The key that identifies the note belonging to the split
+        Returns:
+            (bool): True if the note belongs to the split, False otherwise
+        """
+        if self._split is None:
+            raise ValueError('Split not set')
+        if self._lookup_split[self._split].get(key, False):
+            return True
+        else:
+            return False
+
+    def assign_splits(
+            self,
+            input_file: str,
+            spans_key: str = 'spans',
+            metadata_key: str = 'meta',
+            group_key: str = 'note_id',
+            margin: float = 0.3
+    ) -> NoReturn:
+        """
+        Get the dataset splits - training, validation & testing splits -
+        based on the NER distribution and on whether we want to keep
+        certain notes grouped (e.g. by patient). Populate the internal
+        lookup map that records which split each note group belongs to;
+        this map can then be used (via set_split/check_note) to filter
+        notes by split.
+        Args:
+            input_file (str): The input file
+            spans_key (str): The key where the note spans are present
+            metadata_key (str): The key where the note metadata is present
+            group_key (str): The key where the note group (e.g note_id or patient id etc) is present.
+                             This field is what the notes will be grouped by, and all notes belonging
+                             to this grouping will be in the same split
+            margin (float): Margin of error when maintaining proportions in the splits
+        """
+        # Compute the distribution of NER types in the grouped notes.
+        # For example the distribution of NER types in all notes belonging to a
+        # particular patient
+        self._lookup_split = {
+            'train': dict(),
+            'validation': dict(),
+            'test': dict()
+        }
+        ner_distribution = NERDistribution()
+        for line in open(input_file, 'r'):
+            note = json.loads(line)
+            key = note[metadata_key][group_key]
+            ner_distribution.update_distribution(spans=note[spans_key], key=key)
+        # Initialize the dataset splits object
+        dataset_splits = DatasetSplits(
+            ner_distribution=ner_distribution,
+            train_proportion=self._train_proportion,
+            validation_proportion=self._validation_proportion,
+            test_proportion=self._test_proportion,
+            margin=margin
+        )
+        # Check the note and assign it to a split
+        for line in open(input_file, 'r'):
+            note = json.loads(line)
+            key = note[metadata_key][group_key]
+            split = dataset_splits.get_split(key=key)
+            self.set_split(split)
+            self.__update_split(key)
+
+
+def main() -> NoReturn:
+    """
+    Prepare dataset splits - training, validation & testing splits.
+    Compute the NER type distribution in the dataset.
Based on this distribution,
+    and on whether we want to keep certain notes grouped (e.g. by patient),
+    we assign notes to a split, such that the final NER type distribution
+    in each split is similar.
+    """
+    # The following code sets up the arguments to be passed via the CLI
+    cli_parser = ArgumentParser(
+        description='configuration arguments provided at run time from the CLI',
+        formatter_class=ArgumentDefaultsHelpFormatter
+    )
+    cli_parser.add_argument(
+        '--input_file',
+        type=str,
+        required=True,
+        help='the jsonl file that contains the notes'
+    )
+    cli_parser.add_argument(
+        '--spans_key',
+        type=str,
+        default='spans',
+        help='the key where the note spans are present in the json object'
+    )
+    cli_parser.add_argument(
+        '--metadata_key',
+        type=str,
+        default='meta',
+        help='the key where the note metadata is present in the json object'
+    )
+    cli_parser.add_argument(
+        '--group_key',
+        type=str,
+        default='note_id',
+        help='the key to group notes by in the json object'
+    )
+    cli_parser.add_argument(
+        '--train_proportion',
+        type=int,
+        default=70,
+        help='proportion (in percent) of the train dataset'
+    )
+    cli_parser.add_argument(
+        '--train_file',
+        type=str,
+        default=None,
+        help='the file to store the train data'
+    )
+    cli_parser.add_argument(
+        '--validation_proportion',
+        type=int,
+        default=15,
+        help='proportion (in percent) of the validation dataset'
+    )
+    cli_parser.add_argument(
+        '--validation_file',
+        type=str,
+        default=None,
+        help='the file to store the validation data'
+    )
+    cli_parser.add_argument(
+        '--test_proportion',
+        type=int,
+        default=15,
+        help='proportion (in percent) of the test dataset'
+    )
+    cli_parser.add_argument(
+        '--test_file',
+        type=str,
+        default=None,
+        help='the file to store the test data'
+    )
+    cli_parser.add_argument(
+        '--margin',
+        type=float,
+        default=0.3,
+        help='margin of error when maintaining proportions in the splits'
+    )
+    cli_parser.add_argument(
+        '--print_dist',
+        action='store_true',
+        help='whether to print the label distribution in the splits'
+    )
+    args = cli_parser.parse_args()
+    dataset_splitter = DatasetSplitter(
+        train_proportion=args.train_proportion,
+        validation_proportion=args.validation_proportion,
+        test_proportion=args.test_proportion
+    )
+    dataset_splitter.assign_splits(
+        input_file=args.input_file,
+        spans_key=args.spans_key,
+        metadata_key=args.metadata_key,
+        group_key=args.group_key,
+        margin=args.margin
+    )
+
+    if args.train_proportion > 0:
+        with open(args.train_file, 'w') as file:
+            dataset_splitter.set_split('train')
+            for line in open(args.input_file, 'r'):
+                note = json.loads(line)
+                key = note[args.metadata_key][args.group_key]
+                if dataset_splitter.check_note(key):
+                    file.write(json.dumps(note) + '\n')
+
+    if args.validation_proportion > 0:
+        with open(args.validation_file, 'w') as file:
+            dataset_splitter.set_split('validation')
+            for line in open(args.input_file, 'r'):
+                note = json.loads(line)
+                key = note[args.metadata_key][args.group_key]
+                if dataset_splitter.check_note(key):
+                    file.write(json.dumps(note) + '\n')
+
+    if args.test_proportion > 0:
+        with open(args.test_file, 'w') as file:
+            dataset_splitter.set_split('test')
+            for line in open(args.input_file, 'r'):
+                note = json.loads(line)
+                key = note[args.metadata_key][args.group_key]
+                if dataset_splitter.check_note(key):
+                    file.write(json.dumps(note) + '\n')
+
+    if args.print_dist:
+        # Read the dataset splits file and compute the
NER type distribution + key_counts = Counter() + ner_distribution = NERDistribution() + for line in open(args.input_file, 'r'): + note = json.loads(line) + key = note[args.metadata_key][args.group_key] + key_counts[key] += 1 + ner_distribution.update_distribution(spans=note[args.spans_key], key=key) + print_distribution = PrintDistribution(ner_distribution=ner_distribution, key_counts=key_counts) + train_splits = dataset_splitter.get_split('train') + validation_splits = dataset_splitter.get_split('validation') + test_splits = dataset_splitter.get_split('test') + all_splits = train_splits + validation_splits + test_splits + # Print distribution for each split + print_distribution.split_distribution(split='total', split_info=all_splits) + print_distribution.split_distribution(split='train', split_info=train_splits) + print_distribution.split_distribution(split='validation', split_info=validation_splits) + print_distribution.split_distribution(split='test', split_info=test_splits) + + +if __name__ == "__main__": + main() diff --git a/ner_datasets/distribution/__init__.py b/ner_datasets/distribution/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9dfaae4af850601cf8969aff791d8440cc467da1 --- /dev/null +++ b/ner_datasets/distribution/__init__.py @@ -0,0 +1,4 @@ +from .dataset_splits import DatasetSplits +from .ner_distribution import NERDistribution +from .print_distribution import PrintDistribution +__all__=["DatasetSplits", "NERDistribution", "PrintDistribution"] \ No newline at end of file diff --git a/ner_datasets/distribution/dataset_splits.py b/ner_datasets/distribution/dataset_splits.py new file mode 100644 index 0000000000000000000000000000000000000000..4221b3f9441dbdc79716053c5543cc2f41299200 --- /dev/null +++ b/ner_datasets/distribution/dataset_splits.py @@ -0,0 +1,218 @@ +import random +from collections import Counter +from typing import NoReturn + +from .ner_distribution import NERDistribution + +random.seed(41) + + +class DatasetSplits(object): + """ + Prepare dataset splits - training, validation & testing splits + Compute ner distributions in the dataset. Based on this we assign + notes to different splits and at the same time, we keep the distribution of + NER types in each split similar. . + Keep track of the split information - which notes are present in which split. + The label distribution in each split, the number of notes in each split. + """ + + def __init__( + self, + ner_distribution: NERDistribution, + train_proportion: int, + validation_proportion: int, + test_proportion: int, + margin: float + ) -> NoReturn: + """ + Maintain split information. Assign notes based on the proportion of + the splits, while keeping the label distribution in each split similar. + Keep track of the split information - which notes are present in which split. + The label distribution in each split, the number of notes in each split. + Keep track of the dataset splits and the counts in each split etc. + These will be used to assign the different notes to different + splits while keeping the proportion of ner similar in each split. + Get the maximum number of ner that can be present in the train, + validation and test split. The total count will be used to + calculate the current proportion of ner in the split. 
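+        As a small worked example (hypothetical numbers): with proportions
+        70/15/15 and 1,000 total NER spans in the dataset, the split budgets
+        work out to int(70 * 1000 / 100) = 700, 150 and 150 spans.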
This can be used + to keep the proportion of ner types consistent among different splits + Args: + ner_distribution (NERDistribution): The NER distribution in the dataset + train_proportion (int): Ratio of train dataset + validation_proportion (int): Ratio of validation dataset + test_proportion (int): Ratio of test dataset + margin (float): Margin by which the label distribution can be exceeded in the split + """ + self._ner_distribution = ner_distribution + # Compute the counts of NER types in the entire dataset + total_distribution = Counter() + for key, counts in ner_distribution.get_ner_distribution().items(): + for label, count in counts.items(): + total_distribution[label] += count + # Compute the percentages of NER types in the entire dataset + self._total_ner = sum(total_distribution.values()) + self._label_dist_percentages = { + ner_type: float(count) / self._total_ner * 100 if self._total_ner else 0 + for ner_type, count in total_distribution.items() + } + self._margin = margin + # The three splits + self._splits = ['train', 'validation', 'test'] + self._split_weights = None + self._splits_info = None + # Keep track of the patient_ids that have been processed. + # Since a patient can have multiple notes and we already know the + # ner distribution for this patient across all the notes (i.e the ner types + # and count that appear in all the notes associated with this patient) + # We also keep all the notes associated with a patient in the same split + # So we check if adding all the notes associated with this patient will + # disturb the ner distribution (proportions) as mentioned before. + self._processed_keys = dict() + # Based on these proportions we compute train_ner_count, validation_ner_count, test_ner_count + # Say the proportion are 85, 10, 5 + # The train split will have a maximum of 85% of the overall ner, validation will have 10 and test will 5 + # That is if there are total count of all ner is 100, on splitting the datasets + # the train split will have a total of 85 ner, validation split will have a total of 10 ner and the + # test split will have a total of 5 ner + train_ner_count = int(train_proportion * self._total_ner / 100) + validation_ner_count = int(validation_proportion * self._total_ner / 100) + test_ner_count = int(test_proportion * self._total_ner / 100) + # So based on this, we check if adding a note keeps the balance in proportion or not + # If it does not, we check the splits given in the "remain" field of the dict (which is + # the 2 other splits + self._split_weights = [train_proportion, validation_proportion, test_proportion] + # Based on the split proportions, ner counts and ner distribution + # we need to split our dataset into train, validation and test split + # For each split we try and maintain the same distribution (proportions) between ner types + # that we computed from the entire dataset (given by - ner_distribution) + # If the entire dataset had AGE:50%, DATE:30%, LOC:20%, we want the same proportions + # in each of the train, validation and test splits + # So based on this, we check if adding a note keeps the balance in proportion or not + # If it does not, we check the splits given in the "remain" field of the dict (which is + # the 2 other splits + self._splits_info = {'train': {'remain': ['validation', 'test'], + 'total': train_ner_count, + 'remain_weights': [validation_proportion, test_proportion], + 'groups': list(), 'number_of_notes': 0, 'label_dist': Counter()}, + 'validation': {'remain': ['train', 'test'], + 'total': 
validation_ner_count, + 'remain_weights': [train_proportion, test_proportion], + 'groups': list(), 'number_of_notes': 0, 'label_dist': Counter()}, + 'test': {'remain': ['validation', 'train'], + 'total': test_ner_count, + 'remain_weights': [validation_proportion, train_proportion], + 'groups': list(), 'number_of_notes': 0, 'label_dist': Counter()}} + + def __set_split(self, split: str) -> NoReturn: + """ + Set the split that you are currently checking/processing. + Based on the split you can perform certain checks and + computation for that split. + Args: + split (str): The split - train, validation or test + """ + self._split = split + + def __update_label_dist(self, distribution: Counter) -> NoReturn: + """ + Once we have determined that a note can be added to the split we need to + update the current count of the ner types in the split. So we pass the ner counts + in the note that will be updated and update the counts of the ner types in the split. + Args: + distribution (Counter): Contains the ner type and it's counts (distribution) + """ + self._splits_info[self._split]['label_dist'].update(distribution) + + def __update_groups(self, note_group_key: str) -> NoReturn: + """ + Once we have determined that a note can be added to the split, we append + to a list some distinct element of the note (e.g note_id). This list will + contain the note_ids of the notes that belong to this split. + Args: + note_group_key (str): Contains the note metadata - e.g note_id, institute etc + """ + self._processed_keys[note_group_key] = self._split + self._splits_info[self._split]['groups'].append(note_group_key) + + def __check_split(self, distribution: Counter) -> bool: + """ + This function is used to check the resulting ner distribution in the split on adding this + note to the split. We check how the proportion of ner changes if this note is added to + the split. If the proportion exceeds the desired proportion then we return false + to indicate that adding this note will upset the ner distribution across splits, so we should + instead check adding this note to another split. If it does not update the balance then we return + True, which means we can add this note to this split. The desired proportion of ner is passed + in the percentages argument - where we have the desired proportion for each ner type. + Args: + distribution (Counter): Contains the mapping between ner type and count + Returns: + (bool): True if the note can be added to the split, false otherwise + """ + # Get the current ner types and counts in the split + split_label_dist = self._splits_info[self._split]['label_dist'] + # Get the max ner count that can be present in the split + # This will be used to compute the ner proportions in the split + split_total = self._splits_info[self._split]['total'] + # Check if the proportion of the split picked in zero + # and return False because we cant add any note to this split + if split_total == 0: + return False + for ner_type, count in distribution.items(): + percentage = (split_label_dist.get(ner_type, 0) + count) / split_total * 100 + # Check if the proportion on adding this note exceeds the desired proportion + # within the margin of error + # If it does return false + if percentage > self._label_dist_percentages[ner_type] + self._margin: + return False + return True + + def get_split(self, key: str) -> str: + """ + Assign a split to the note - based on the distribution of ner types in the note + and the distribution of ner types in the split. 
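+        A hypothetical call:
+
+            split = dataset_splits.get_split(key='patient_1')
+            # -> 'train', 'validation' or 'test'
+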
Essentially assign a note to a split + such that the distribution of ner types in each split is similar, once all notes have + been assigned to their respective splits. + Args: + key (str): The note id or patient id of the note (some grouping key) + Returns: + (str): The split + """ + current_splits = self._splits + current_weights = self._split_weights + distribution = self._ner_distribution.get_group_distribution(key=key) + if self._processed_keys.get(key, False): + return self._processed_keys[key] + while True: + # Pick and set the split + check_split = random.choices(current_splits, current_weights)[0] + self.__set_split(check_split) + # Get the ner distribution for this particular patient (across all the notes associated + # with this patient) and check if the notes can be added to this split. + # The margin of error for the ner proportions. As we said above we try and keep the proportions + # across the splits the same, but we allow for some flexibility, so we can go +- the amount + # given by margin. + include = self.__check_split(distribution=distribution) + if include: + self.__update_groups(key) + self.__update_label_dist(distribution=distribution) + return check_split + else: + # Check the two other possible splits + if len(current_splits) == 3: + current_splits = self._splits_info[check_split]['remain'] + current_weights = self._splits_info[check_split]['remain_weights'] + # Check the one other possible split (when the one of the above two other split check returns false) + elif len(current_splits) == 2 and current_weights[1 - current_splits.index(check_split)] != 0: + index = current_splits.index(check_split) + current_splits = [current_splits[1 - index]] + current_weights = [100] + # If it can't be added to any split - choose a split randomly + else: + current_splits = self._splits + current_weights = self._split_weights + check_split = random.choices(current_splits, current_weights)[0] + self.__set_split(check_split) + self.__update_groups(key) + self.__update_label_dist(distribution=distribution) + return check_split diff --git a/ner_datasets/distribution/ner_distribution.py b/ner_datasets/distribution/ner_distribution.py new file mode 100644 index 0000000000000000000000000000000000000000..9bf9b3e7e771e337f3cade3e5509150013186537 --- /dev/null +++ b/ner_datasets/distribution/ner_distribution.py @@ -0,0 +1,54 @@ +from collections import Counter, defaultdict +from typing import Sequence, Mapping, NoReturn + + +class NERDistribution(object): + """ + Store the distribution of ner types based on some key. + That is we store the NER type distribution for some given key value and we update + the distribution when spans related to that key is passed + """ + + def __init__(self) -> NoReturn: + """ + Initialize the NER type - count mapping + """ + # Counter the captures the ner types and counts per patient/note_id in the dataset + # Depending on what we set the group_key as. Basically gather counts with respect + # to some grouping of the notes + # E.g - {{PATIENT 1: {AGE: 99, DATE: 55, ...}, {PATIENT 2: {AGE: 5, DATE: 9, ...} ... 
}
+        self._ner_distribution = defaultdict(Counter)
+
+    def update_distribution(self, spans: Sequence[Mapping[str, str]], key: str) -> NoReturn:
+        """
+        Update the distribution of NER types for the given key.
+        Args:
+            spans (Sequence[Mapping[str, str]]): The list of spans in the note
+            key (str): The note id or patient id of the note (some grouping)
+        """
+        # Go through the spans in the note and update the NER distribution
+        # for this key (e.g. the NER types across all the notes associated
+        # with a patient). The defaultdict creates an empty Counter the
+        # first time a key is seen
+        for span in spans:
+            self._ner_distribution[key][span['label']] += 1
+
+    def get_ner_distribution(self) -> defaultdict:
+        """
+        Return the overall NER distribution, i.e. the NER type distribution
+        for every key.
+        Returns:
+            ner_distribution (defaultdict(Counter)): Overall NER type distribution for all keys
+        """
+        return self._ner_distribution
+
+    def get_group_distribution(self, key: str) -> Counter:
+        """
+        Return the NER type distribution for the given key.
+        Returns:
+            (Counter): NER distribution w.r.t some grouping (key)
+        """
+        if key in self._ner_distribution:
+            return self._ner_distribution[key]
+        else:
+            raise ValueError('Key not found')
diff --git a/ner_datasets/distribution/print_distribution.py b/ner_datasets/distribution/print_distribution.py
new file mode 100644
index 0000000000000000000000000000000000000000..1c3fb3cedd17e4c492cf2192f069b0c0863eee6d
--- /dev/null
+++ b/ner_datasets/distribution/print_distribution.py
@@ -0,0 +1,49 @@
+from collections import Counter
+from typing import Sequence, NoReturn
+
+from .ner_distribution import NERDistribution
+
+
+class PrintDistribution(object):
+    """
+    This class is used to print the distribution of NER types
+    """
+
+    def __init__(self, ner_distribution: NERDistribution, key_counts: Counter) -> NoReturn:
+        """
+        Initialize.
+        Args:
+            ner_distribution (NERDistribution): NERDistribution object that keeps track of the NER type distributions
+            key_counts (Counter): Number of notes per key/group (e.g note_ids, patient ids etc)
+        """
+        self._ner_distribution = ner_distribution
+        self._key_counts = key_counts
+
+    def split_distribution(self, split: str, split_info: Sequence[str]) -> NoReturn:
+        """
+        Print the NER type distribution for a split.
+        Args:
+            split (str): The dataset split
+            split_info (Sequence[str]): The keys belonging to that split
+        """
+        split_distribution = Counter()
+        number_of_notes = 0
+        for key in split_info:
+            number_of_notes += self._key_counts[key]
+            split_distribution.update(self._ner_distribution.get_group_distribution(key))
+        total_ner = sum(split_distribution.values())
+        percentages = {ner_type: float(count) / total_ner * 100 if total_ner else 0
+                       for ner_type, count in split_distribution.items()}
+        print('{:^70}'.format('============ ' + split.upper() + ' NER Distribution ============='))
+        print('{:<20}{:<10}'.format('Number of Notes: ', number_of_notes))
+        print('{:<20}{:<10}\n'.format('Number of Groups: ', len(split_info)))
+        for ner_type, count in split_distribution.most_common():
+            print('{:<10}{:<10}{:<5}{:<10}{:<5}{:<10}'.format(
+                'NER Type: ', ner_type,
+                'Count: ', count,
+                'Percentage: ', '{:0.2f}'.format(percentages[ner_type]))
+            )
+        print('{:<10}{:<10}{:<5}{:<10}{:<5}{:<10}'.format(
+            'NER Type:', 'TOTALS', 'Count: ', total_ner, 'Percentage: ', '{:0.2f}'.format(100))
+        )
+        print('\n')
diff --git a/ner_datasets/preprocessing/__init__.py b/ner_datasets/preprocessing/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..8dc8fd7c0b88b04b7dad124fddfff2243ad6eef5
--- /dev/null
+++ b/ner_datasets/preprocessing/__init__.py
@@ -0,0 +1,2 @@
+from .preprocessing_loader import PreprocessingLoader
+__all__ = ["PreprocessingLoader"]
diff --git a/ner_datasets/preprocessing/preprocessing_loader.py b/ner_datasets/preprocessing/preprocessing_loader.py
new file mode 100644
index 0000000000000000000000000000000000000000..f471eb4c3d6b6794889ab3ca28799767197604a4
--- /dev/null
+++ b/ner_datasets/preprocessing/preprocessing_loader.py
@@ -0,0 +1,63 @@
+from typing import Union, Optional, Sequence
+
+from .sentencizers import SpacySentencizer, NoteSentencizer
+from .tokenizers import ClinicalSpacyTokenizer, SpacyTokenizer, CoreNLPTokenizer
+
+
+class PreprocessingLoader(object):
+
+    @staticmethod
+    def get_sentencizer(sentencizer: str) -> Union[SpacySentencizer, NoteSentencizer]:
+        """
+        Initialize the sentencizer based on the CLI arguments.
+        We can either use a spacy/scispacy model (en_core_sci_lg or
+        en_core_web_sm) or consider the entire note as a single sentence.
+        Args:
+            sentencizer (str): Specify which sentencizer you want to use
+        Returns:
+            Union[SpacySentencizer, NoteSentencizer]: An object of the requested
+                                                      sentencizer class
+        """
+        if sentencizer == 'en_core_sci_lg':
+            return SpacySentencizer(spacy_model='en_core_sci_lg')
+        elif sentencizer == 'en_core_web_sm':
+            return SpacySentencizer(spacy_model='en_core_web_sm')
+        elif sentencizer == 'note':
+            return NoteSentencizer()
+        else:
+            raise ValueError('Invalid sentencizer - does not exist')
+
+    @staticmethod
+    def get_tokenizer(
+            tokenizer: str,
+            abbreviations: Optional[Sequence[str]] = None,
+    ) -> Union[SpacyTokenizer, ClinicalSpacyTokenizer, CoreNLPTokenizer]:
+        """
+        Initialize the tokenizer based on the CLI arguments.
+        We can either use the default scispacy (en_core_sci_lg or en_core_web_sm)
+        tokenizer or the modified scispacy (with regex rules) tokenizer.
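+        A hypothetical call, assuming the scispacy en_core_sci_lg model is
+        installed:
+
+            tokenizer = PreprocessingLoader.get_tokenizer(tokenizer='clinical')
+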
+        It also supports the CoreNLP tokenizer (which requires a running
+        CoreNLP server).
+        Args:
+            tokenizer (str): Specify which tokenizer you want to use
+            abbreviations (Optional[Sequence[str]]): A list of abbreviations for which tokens will not be split -
+                                                     works only when used with the custom clinical tokenizer
+        Returns:
+            Union[SpacyTokenizer, ClinicalSpacyTokenizer, CoreNLPTokenizer]: An object of the requested tokenizer class
+        """
+        if tokenizer == 'en_core_sci_lg':
+            return SpacyTokenizer(spacy_model='en_core_sci_lg')
+        elif tokenizer == 'en_core_web_sm':
+            return SpacyTokenizer(spacy_model='en_core_web_sm')
+        elif tokenizer == 'en':
+            return SpacyTokenizer(spacy_model='en')
+        elif tokenizer == 'corenlp':
+            return CoreNLPTokenizer()
+        elif tokenizer == 'clinical':
+            # Abbreviations - we won't split tokens that match these (e.g 18F-FDG).
+            # The abbreviations argument may be None, in which case the clinical
+            # tokenizer falls back to its default behaviour
+            return ClinicalSpacyTokenizer(spacy_model='en_core_sci_lg', abbreviations=abbreviations)
+        else:
+            raise ValueError('Invalid tokenizer - does not exist')
diff --git a/ner_datasets/preprocessing/sentencizers/__init__.py b/ner_datasets/preprocessing/sentencizers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..66731beed6edaaa89f7cb5b9d7a813c2f47f9a2f
--- /dev/null
+++ b/ner_datasets/preprocessing/sentencizers/__init__.py
@@ -0,0 +1,3 @@
+from .note_sentencizer import NoteSentencizer
+from .spacy_sentencizer import SpacySentencizer
+__all__=["NoteSentencizer", "SpacySentencizer"]
\ No newline at end of file
diff --git a/ner_datasets/preprocessing/sentencizers/mimic_stanza_sentencizer.py b/ner_datasets/preprocessing/sentencizers/mimic_stanza_sentencizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..f6e59caa775398fd51717e8a247673bac6695735
--- /dev/null
+++ b/ner_datasets/preprocessing/sentencizers/mimic_stanza_sentencizer.py
@@ -0,0 +1,37 @@
+from typing import Iterable, Dict, Union
+
+import stanza
+
+
+class MimicStanzaSentencizer(object):
+    """
+    This class is used to read text and split it into
+    sentences (and their start and end positions)
+    using the mimic stanza package
+    """
+
+    def __init__(self, package: str):
+        """
+        Initialize a mimic stanza model to read text and split it into
+        sentences.
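+        A hypothetical instantiation, assuming the clinical 'mimic' package
+        has been downloaded (e.g. via stanza.download('en', package='mimic')):
+
+            sentencizer = MimicStanzaSentencizer(package='mimic')
+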
+        Args:
+            package (str): Name of the mimic model
+        """
+        self._nlp = stanza.Pipeline('en', package=package, processors='tokenize', use_gpu=True)
+
+    def get_sentences(self, text: str) -> Iterable[Dict[str, Union[str, int]]]:
+        """
+        Return an iterator that iterates through the sentences in the text
+        Args:
+            text (str): The text
+        Returns:
+            (Iterable[Dict[str, Union[str, int]]]): Yields a dictionary that contains the text of the sentence,
+                                                    the start position of the sentence in the entire text
+                                                    and the end position of the sentence in the entire text
+        """
+        doc = self._nlp(text)
+        for sentence in doc.sentences:
+            yield {'text': sentence.text,
+                   'start': sentence.tokens[0].start_char,
+                   'end': sentence.tokens[-1].end_char,
+                   'last_token': sentence.tokens[-1].text}
diff --git a/ner_datasets/preprocessing/sentencizers/note_sentencizer.py b/ner_datasets/preprocessing/sentencizers/note_sentencizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..f11e09da520bb6011b75c6580636506a40914198
--- /dev/null
+++ b/ner_datasets/preprocessing/sentencizers/note_sentencizer.py
@@ -0,0 +1,33 @@
+from typing import Iterable, Dict, Union
+
+
+class NoteSentencizer(object):
+    """
+    This class is used to read text and split it into
+    sentences (and their start and end positions).
+    This class considers an entire note or text as
+    a single sentence
+    """
+
+    def __init__(self):
+        """
+        Nothing to be initialized
+        """
+
+    def get_sentences(self, text: str) -> Iterable[Dict[str, Union[str, int]]]:
+        """
+        Return an iterator that iterates through the sentences in the text.
+        In this case it just returns the text itself.
+        Args:
+            text (str): The text
+        Returns:
+            (Iterable[Dict[str, Union[str, int]]]): Yields a dictionary that contains the text of the sentence,
+                                                    the start position of the sentence in the entire text
+                                                    and the end position of the sentence in the entire text
+        """
+        yield {
+            'text': text,
+            'start': 0,
+            'end': len(text),
+            'last_token': None
+        }
diff --git a/ner_datasets/preprocessing/sentencizers/spacy_sentencizer.py b/ner_datasets/preprocessing/sentencizers/spacy_sentencizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..54cc76ee8270b9179dca5c528424136533e350bf
--- /dev/null
+++ b/ner_datasets/preprocessing/sentencizers/spacy_sentencizer.py
@@ -0,0 +1,37 @@
+from typing import Iterable, Dict, Union
+
+import spacy
+
+
+class SpacySentencizer(object):
+    """
+    This class is used to read text and split it into
+    sentences (and their start and end positions)
+    using a spacy model
+    """
+
+    def __init__(self, spacy_model: str):
+        """
+        Initialize a spacy model to read text and split it into
+        sentences.
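+        A hypothetical usage sketch, assuming the model is installed:
+
+            sentencizer = SpacySentencizer(spacy_model='en_core_web_sm')
+            for sentence in sentencizer.get_sentences('Pt seen today. Doing well.'):
+                print(sentence['start'], sentence['end'], sentence['text'])
+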
+ Args: + spacy_model (str): Name of the spacy model + """ + self._nlp = spacy.load(spacy_model) + + def get_sentences(self, text: str) -> Iterable[Dict[str, Union[str, int]]]: + """ + Return an iterator that iterates through the sentences in the text + Args: + text (str): The text + Returns: + (Iterable[Dict[str, Union[str, int]]]): Yields a dictionary that contains the text of the sentence + the start position of the sentence in the entire text + and the end position of the sentence in the entire text + """ + document = self._nlp(text) + for sentence in document.sents: + yield {'text': sentence.text, + 'start': sentence.start_char, + 'end': sentence.end_char, + 'last_token': None} diff --git a/ner_datasets/preprocessing/tokenizers/__init__.py b/ner_datasets/preprocessing/tokenizers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8ef4768b708d025f4b151aa5d5c6cb11b6b97fb2 --- /dev/null +++ b/ner_datasets/preprocessing/tokenizers/__init__.py @@ -0,0 +1,4 @@ +from .spacy_tokenizer import SpacyTokenizer +from .core_nlp_tokenizer import CoreNLPTokenizer +from .clinical_spacy_tokenizer import ClinicalSpacyTokenizer +__all__=["SpacyTokenizer", "CoreNLPTokenizer", "ClinicalSpacyTokenizer"] \ No newline at end of file diff --git a/ner_datasets/preprocessing/tokenizers/abbreviations/check.txt b/ner_datasets/preprocessing/tokenizers/abbreviations/check.txt new file mode 100644 index 0000000000000000000000000000000000000000..b268f9a611b0fc970f8ff006260cda845ecfad4c --- /dev/null +++ b/ner_datasets/preprocessing/tokenizers/abbreviations/check.txt @@ -0,0 +1,20 @@ +sec. +secs. +Sec. +Secs. +fig. +figs. +Fig. +Figs. +eq. +eqs. +Eq. +Eqs. +no. +nos. +No. +Nos. +al. +gen. +sp. +nov. diff --git a/ner_datasets/preprocessing/tokenizers/abbreviations/medical_abbreviations_curated.txt b/ner_datasets/preprocessing/tokenizers/abbreviations/medical_abbreviations_curated.txt new file mode 100644 index 0000000000000000000000000000000000000000..69bbce09d09ef55af38dd75a107617ad6a097215 --- /dev/null +++ b/ner_datasets/preprocessing/tokenizers/abbreviations/medical_abbreviations_curated.txt @@ -0,0 +1,87 @@ +-ve ++ve +a.c. +a/g +b.i.d. +C&S +C/O +D/C +D&C +D and C +H&H +H&P +h.s. +H/O +h/o +I&D +M/H +N/V +O&P +O.D. +O.S. +O.U. +p¯ +p.o. +p.r.n. +q.d. +q.i.d. +R/O +s/p +T&A +t.i.d. +u/a +u** +y.o. +F/u +Crohn's +R.N. +S/p +S/P +s/P +N/A +n/a +N/a +n/A +w/ +Pt. +pt. +PT. +cf. +CF. +Cf. +dr. +DR. +Dr. +ft. +FT. +Ft. +lt. +LT. +Lt. +mr. +MR. +Mr. +ms. +MS. +Ms. +mt. +MT. +Mt. +mx. +MX. +Mx. +ph. +PH. +Ph. +rd. +RD. +Rd. +st. +ST. +St. +vs. +VS. +Vs. +wm. +WM. +Wm. \ No newline at end of file diff --git a/ner_datasets/preprocessing/tokenizers/abbreviations/medical_abbreviations_wiki.txt b/ner_datasets/preprocessing/tokenizers/abbreviations/medical_abbreviations_wiki.txt new file mode 100644 index 0000000000000000000000000000000000000000..86e52b164e21013088dbedbf9c697317864af632 --- /dev/null +++ b/ner_datasets/preprocessing/tokenizers/abbreviations/medical_abbreviations_wiki.txt @@ -0,0 +1,459 @@ ++ve +x/12 +x/40 +x/52 +x/7 +18F-FDG +2° +2/2 +3TC +5-FU +5-HIAA +5-HT +6MP +a.a. +A1C +Aa. +AAOx3 +A/B +a.c. +AC&BC +ad. +part. +A+E +AF-AFP +a.h. +altern. +d. +Anti- +A&O +A/O +A&Ox3 +A&Ox4 +a.p. +A&P +A/P +applic. +aq. +bull. +calid. +dist. +gel. +ASC-H +ASC-US +A-T +AT-III +aur. +dextro. +aurist. +A&W +A/W +b.i.d. 
+b/l +bl.cult +B/O +BRCA1 +BRCA2 +C1 +C2 +c/b +CBC/DIFF +C/C/E +CCK-PZ +CHEM-7 +CHEM-20 +C/O +c/o +CO2 +COX-1 +COX-2 +COX-3 +C/S +C&S +C-section +C-spine +C-SSRS +c/a/p +c/w +D5 +D25 +D4T +D5W +D&C +D/C +D&E +DHEA-S +Di-Di +DM2 +D/O +D/T +Ex-n +F/C +F/C/S +FEF25–75 +FEV1 +fl.oz. +FTA-ABS +F/U +G6PD +G-CSF +GM-CSF +H/A +HbA1c +HCO3 +HDL-C +H&E +H/H +H&H +H&M +HMG-CoA +H-mole +H/O +H&P +H/oPI +h.s. +I131 +ICD-10 +I&D +IgG4-RD +IgG4-RKD +IgG4-ROD +IgG4-TIN +INF(-α/-β/-γ) +I&O +IV-DSA +L&D +LDL-C +L-DOPA +L/S +MC&S +M/E +MgSO4 +MHA-TP +M&M +MMR-D +Mono-Di +Mono-Mono +MS-AFP +MSO4 +MVo2 +No. +rep. +n.s. +n/t +N&V +n/v +O2 +OB-GYN +ob-gyne +O/E +O/N +O&P +P&A +PAI-1 +PAPP-A +p.c. +PIG-A +PM&R +p.r. +Pt. +p.v. +P-Y +q2wk +q6h +q6° +q.a.d. +q.AM +q.d. +q.d.s. +q.h. +q.h.s. +q.i.d. +q.l. +q.m.t. +q.n. +q.n.s. +q.o.d. +q.o.h. +q.s. +q.v. +q.wk. +r/g/m +R&M +R/O +r/r/w +R/t +RT-PCR +S1 +S2 +S3 +S4 +S&O +S.D. +op. +SMA-6 +SMA-7 +s/p +spp. +Sp. +fl. +gr. +S/S +S/Sx +Staph. +Strep. +Strepto. +T&A +T&C +T&S +TAH-BSO +T2DM +T/F +T&H +Tib-Fib +TRF'd +TSHR-Ab +T.S.T.H. +U/A +U&E +U/O +V-fib +V/Q +WAIS-R +W/C +WISC-R +W/O +w/o +w/u +X-AFP +y/o +a.c.h.s. +ac&hs +a.d. +ad. +add. +lib. +admov. +us. +æq. +agit. +alt. +d. +dieb. +h. +hor. +a.m. +amp. +com. +dest. +ferv. +a.l. +a.s. +a.u. +b.d.s. +bib. +b.i.d. +b.d. +ind. +bol. +Ph.Br. +b.t. +bucc. +cap. +caps. +cap. +c.m. +c.m.s. +c. +cib. +c.c. +cf. +c.n. +cochl. +ampl. +infant. +mag. +mod. +parv. +colet. +comp. +contin. +cpt. +cr. +cuj. +c.v. +cyath. +vinos. +D5LR +D5NS +D5W +D10W +D10W +D/C +decoct. +det. +dil. +dim. +p. +æ. +disp. +div. +d.t.d. +elix. +e.m.p. +emuls. +exhib. +f. +f.h. +fl. +fld. +f.m. +pil. +f.s.a. +ft. +garg. +gutt. +habt. +decub. +intermed. +tert. +inj. +i.m. +inf. +i.v. +i.v.p. +lat. +dol. +lb. +l.c.d. +liq. +lot. +M. +m. +max. +m.d.u. +mg/dL +min. +mist. +mit. +mitt. +præscript. +neb. +noct. +n.p.o. +1/2NS +o 2 +o2 +o.d. +o.m. +omn. +bih. +o.n. +o.s. +o.u. +p.c.h.s. +pc&hs +Ph.Br. +Ph.Eur. +Ph.Int. +pig./pigm. +p.m. +p.o. +ppt. +p.r. +p.r.n. +pt. +pulv. +p.v. +q.1 +q.1° +q4PM +q.a.m. +q.d./q.1.d. +q.d.a.m. +q.d.p.m. +q.p.m. +q.q. +q.q.h. +a.d +rep. +rept. +R/L +s. +s.a. +sem. +s.i.d. +sig. +sing. +s.l. +sol. +s.o.s. +s.s. +st. +sum. +supp. +susp. +syr. +tab. +tal. +t. +t.d.s. +t.i.d. +t.d. +tinct. +t.i.w. +top. +tinc. +trit. +troch. +u.d. +ut. +dict. +ung. +vag. +w/a +w/f +y.o. 
+ADD-RT
+A-T
+PDD-NOS
+Alzheimer's
+Age-related
+Aldosterone-producing
+Alcohol-related
+Ataxia-telangiectasia
+Binswanger's
+Becker's
+Bloom's
+Brown-Séquard
+Crimean-Congo
+Cerebro-oculo-facio-skeletal
+Carbapenem-resistant
+Drug-resistant
+End-stage
+Graft-versus-host
+Huntington's
+High-functioning
+Hypoxanthine-guanine
+Legionnaires'
+Low-functioning
+Multi-drug-resistant
+Multi-infarct
+Machado-Joseph
+Maturity-onset
+Multi-sensory
+Obsessive-compulsive
+Parkinson's
+kinase-associated
+Post-polio
+Port-wine
+Reye's
+Sensory-based
+Vitus's
+Septo-optic
+ST-elevation
+Short-lasting
+Urticaria-deafness-amyloidosis
+Wilson's
+drug-resistant
+X-linked
diff --git a/ner_datasets/preprocessing/tokenizers/clinical_spacy_tokenizer.py b/ner_datasets/preprocessing/tokenizers/clinical_spacy_tokenizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..c8867e7382820226d43aabe0b5ab895b7a3d1fe8
--- /dev/null
+++ b/ner_datasets/preprocessing/tokenizers/clinical_spacy_tokenizer.py
@@ -0,0 +1,73 @@
+import re
+import spacy
+from spacy.symbols import ORTH
+from .spacy_tokenizer import SpacyTokenizer
+from .utils import DateRegex, CleanRegex, ClinicalRegex
+
+
+class ClinicalSpacyTokenizer(SpacyTokenizer):
+    """
+    This class is used to read text and return the tokens
+    present in the text (and their start and end positions),
+    using tokenization rules tailored to clinical text
+    """
+
+    def __init__(self, spacy_model, abbreviations,
+                 split_multiple=True, split_temperature=True,
+                 split_percentage=True):
+        """
+        Initialize a spacy model to read text and split it into
+        tokens.
+        Args:
+            spacy_model (str): Name of the spacy model
+            abbreviations (Optional[Sequence[str]]): Tokens that should never be split
+            split_multiple (bool): Whether to split expressions like x2
+            split_temperature (bool): Whether to split temperature values from their unit
+            split_percentage (bool): Whether to split percentage values from the % sign
+        """
+        super().__init__(spacy_model)
+        self._nlp.tokenizer.prefix_search = self.__get_prefix_regex(split_multiple, split_temperature,
+                                                                    split_percentage).search
+        self._nlp.tokenizer.infix_finditer = self.__get_infix_regex().finditer
+        self._nlp.tokenizer.suffix_search = self.__get_suffix_regex().search
+        # Drop the default tokenizer exceptions that clash with the clinical
+        # date handling (month abbreviations and 'wed')
+        new_rules = {}
+        for orth, exc in self._nlp.tokenizer.rules.items():
+            if re.search('((Jan|Feb|Mar|Apr|May|Jun|Jul|Aug|Sep|Sept|Oct|Nov|Dec)[.]$)|(^(W|w)ed$)', orth):
+                continue
+            new_rules[orth] = exc
+        self._nlp.tokenizer.rules = new_rules
+        if abbreviations is not None:
+            for abbreviation in abbreviations:
+                special_case = [{ORTH: abbreviation}]
+                self._nlp.tokenizer.add_special_case(abbreviation, special_case)
+            # This matches any lower case token - TODO: abstract this part out
+            # (whether to lower case abbreviations or not)
+            exclusions_uncased = {abbreviation.lower(): [{ORTH: abbreviation.lower()}] for abbreviation in
+                                  abbreviations}
+            for k, excl in exclusions_uncased.items():
+                try:
+                    self._nlp.tokenizer.add_special_case(k, excl)
+                except Exception:
+                    print('failed to add exception: {}'.format(k))
+
+    def __get_prefix_regex(self, split_multiple, split_temperature, split_percentage):
+        date_prefix = DateRegex.get_infixes()
+        clinical_prefix = ClinicalRegex.get_prefixes(split_multiple, split_temperature, split_percentage)
+        clean_prefix = CleanRegex.get_prefixes()
+        digit_infix = ClinicalRegex.get_infixes()
+        prefixes = clean_prefix + self._nlp.Defaults.prefixes + date_prefix + clinical_prefix + digit_infix
+        prefix_regex = spacy.util.compile_prefix_regex(prefixes)
+        return prefix_regex
+
+    def __get_suffix_regex(self):
+        clean_suffix = CleanRegex.get_suffixes()
+        suffixes = clean_suffix + self._nlp.Defaults.suffixes
+        suffix_regex = spacy.util.compile_suffix_regex(suffixes)
+        return suffix_regex
+
+    def __get_infix_regex(self):
+        date_infixes = DateRegex.get_infixes()
+        clean_infixes = CleanRegex.get_infixes()
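+        # spacy's default infixes are combined with the custom date and
+        # clean-up infix patterns below before being compiled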
digit_infix = ClinicalRegex.get_infixes() + infixes = self._nlp.Defaults.infixes + date_infixes + clean_infixes + infix_re = spacy.util.compile_infix_regex(infixes) + return infix_re + + def get_nlp(self): + return self._nlp diff --git a/ner_datasets/preprocessing/tokenizers/core_nlp_tokenizer.py b/ner_datasets/preprocessing/tokenizers/core_nlp_tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..1bb61bf8fbc5908e874f3224c043df2977da03cc --- /dev/null +++ b/ner_datasets/preprocessing/tokenizers/core_nlp_tokenizer.py @@ -0,0 +1,58 @@ +import json +from typing import Iterable, Mapping, Dict, Union + +from pycorenlp import StanfordCoreNLP + + +class CoreNLPTokenizer(object): + """ + This class is used to read text and return the tokens + present in the text (and their start and end positions) + using core nlp tokenization + """ + + def __init__(self, port: int = 9000): + """ + Initialize a core nlp server to read text and split it into + tokens using the core nlp annotators + Args: + port (int): The port to run the server on + """ + self._core_nlp = StanfordCoreNLP('http://localhost:{0}'.format(port)) + + def get_stanford_annotations(self, text: str, annotators: str = 'tokenize,ssplit,pos,lemma') -> Dict: + """ + Use the core nlp server to annotate the text and return the + results as a json object + Args: + text (str): The text to annotate + annotators (str): The core nlp annotations to run on the text + Returns: + output (Dict): The core nlp results + """ + output = self._core_nlp.annotate(text, properties={ + "timeout": "50000", + "ssplit.newlineIsSentenceBreak": "two", + 'annotators': annotators, + 'outputFormat': 'json' + }) + if type(output) is str: + output = json.loads(output, strict=False) + return output + + def get_tokens(self, text: str) -> Iterable[Dict[str, Union[str, int]]]: + """ + Return an iterable that iterates through the tokens in the text + Args: + text (str): The text to annotate + Returns: + (Iterable[Mapping[str, Union[str, int]]]): Yields a dictionary that contains the text of the token + the start position of the token in the entire text + and the end position of the token in the entire text + """ + stanford_output = self.get_stanford_annotations(text) + for sentence in stanford_output['sentences']: + for token in sentence['tokens']: + yield {'text': token['originalText'], + 'start': token['characterOffsetBegin'], + 'end': token['characterOffsetEnd']} diff --git a/ner_datasets/preprocessing/tokenizers/spacy_tokenizer.py b/ner_datasets/preprocessing/tokenizers/spacy_tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..7756bf2d1984fb493d6c7df58e8dc3f4c4c51bc1 --- /dev/null +++ b/ner_datasets/preprocessing/tokenizers/spacy_tokenizer.py @@ -0,0 +1,49 @@ +import spacy +from typing import Tuple, Iterable, Mapping, Dict, Union + + +class SpacyTokenizer(object): + """ + This class is used to read text and return the tokens + present in the text (and their start and end positions) + using spacy + """ + + def __init__(self, spacy_model: str): + """ + Initialize a spacy model to read text and split it into + tokens. 
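+        A hypothetical usage sketch, assuming the model is installed:
+
+            tokenizer = SpacyTokenizer(spacy_model='en_core_web_sm')
+            tokens = list(tokenizer.get_tokens('BP was 120/80 today'))
+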
+        Args:
+            spacy_model (str): Name of the spacy model
+        """
+        self._nlp = spacy.load(spacy_model)
+
+    @staticmethod
+    def __get_start_and_end_offset(token: spacy.tokens.Token) -> Tuple[int, int]:
+        """
+        Return the start position of the token in the entire text
+        and the end position of the token in the entire text
+        Args:
+            token (spacy.tokens.Token): The spacy token object
+        Returns:
+            start (int): the start position of the token in the entire text
+            end (int): the end position of the token in the entire text
+        """
+        start = token.idx
+        end = start + len(token)
+        return start, end
+
+    def get_tokens(self, text: str) -> Iterable[Dict[str, Union[str, int]]]:
+        """
+        Return an iterable that iterates through the tokens in the text
+        Args:
+            text (str): The text to annotate
+        Returns:
+            (Iterable[Dict[str, Union[str, int]]]): Yields a dictionary that contains the text of the token,
+                                                    the start position of the token in the entire text
+                                                    and the end position of the token in the entire text
+        """
+        document = self._nlp(text)
+        for token in document:
+            start, end = SpacyTokenizer.__get_start_and_end_offset(token)
+            yield {'text': token.text, 'start': start, 'end': end}
diff --git a/ner_datasets/preprocessing/tokenizers/utils/__init__.py b/ner_datasets/preprocessing/tokenizers/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..65a8fd56074731b1b20c2055bc4e5a05bf03355b
--- /dev/null
+++ b/ner_datasets/preprocessing/tokenizers/utils/__init__.py
@@ -0,0 +1,4 @@
+from .date_regex import DateRegex
+from .clean_regex import CleanRegex
+from .clinical_regex import ClinicalRegex
+__all__=["DateRegex", "CleanRegex", "ClinicalRegex"]
\ No newline at end of file
diff --git a/ner_datasets/preprocessing/tokenizers/utils/clean_regex.py b/ner_datasets/preprocessing/tokenizers/utils/clean_regex.py
new file mode 100644
index 0000000000000000000000000000000000000000..005c270365eebe05593efa79e63b504e3a276662
--- /dev/null
+++ b/ner_datasets/preprocessing/tokenizers/utils/clean_regex.py
@@ -0,0 +1,64 @@
+from typing import List
+class CleanRegex(object):
+    """
+    This class is used to define the regexes that will be used by the
+    spacy tokenizer rules. Mainly the regexes are used to clean up
+    tokens that have unwanted characters (e.g extra hyphens).
+    """
+    #Staff - 3
+    #Hosp - 4, 5
+    #Loc - 2
+    @staticmethod
+    def get_prefixes() -> List[str]:
+        """
+        This function is used to build the regex that will clean up dirty characters
+        present at the prefix position (start position) of a token. For example the token ---clean
+        has three hyphens that need to be split from the word clean. This regex
+        will be used by spacy to clean it up. This rule treats any character that is
+        not a letter or a digit as a dirty character
+        Examples: ----------------9/36, :63, -ESH
+        Returns:
+            (list): List of regexes to clean the prefix of the token
+        """
+        #Handles case 5 of HOSP
+        return ['((?P<prefix>([^a-zA-Z0-9.]))(?P=prefix)*)', '([.])(?!\d+(\W+|$))']
+
+    @staticmethod
+    def get_suffixes() -> List[str]:
+        """
+        This function is used to build the regex that will clean up dirty characters
+        present at the suffix position (end position) of a token. For example the token clean---
+        has three hyphens that need to be split from the word clean. This regex
+        will be used by spacy to clean it up.
This rule treats any character that is
+        not a letter or a digit as a dirty character
+        Examples: FRANK^, regimen---------------, no)
+        Returns:
+            (list): List of regexes to clean the suffix of the token
+        """
+        return ['((?P<suffix>([^a-zA-Z0-9]))(?P=suffix)*)']
+
+    @staticmethod
+    def get_infixes() -> List[str]:
+        """
+        This function is used to build the regex that will clean up dirty characters
+        present at the infix position (in-between position) of a token. For example the token
+        clean---me has three hyphens that need to be split from the words clean and me. This regex
+        will be used by spacy to clean it up. This rule treats any character that is
+        not a letter or a digit as a dirty character
+        Examples: FRANK^08/30/76^UNDERWOOD, regimen---------------1/37
+        Returns:
+            (list): List of regexes to clean the infix of the token
+        """
+        #Handles case 3 of STAFF
+        #Handles case 4 of HOSP
+        #Handles case 2 of LOC
+        connector_clean = '\^|;|&#|([\(\)\[\]:="])'
+        #full_stop_clean = '(?<=[a-zA-Z])(\.)(?=([A-Z][A-Za-z]+)|[^a-zA-Z0-9_.]+)'
+        bracket_comma_clean = '(((?<=\d)[,)(](?=[a-zA-Z]+))|((?<=[a-zA-Z])[,)(](?=\w+)))'
+        #special_char_clean = '(?<=[a-zA-Z])(\W{3,}|[_]{3,})(?=[A-Za-z]+)'
+        special_char_clean = '(?<=[a-zA-Z])([_\W_]{3,})(?=[A-Za-z]+)'
+        #Sometimes when there is no space between a period and a comma - it becomes part of the same token
+        #e.g John.,M.D - we need to split this up.
+        comma_period_clean = '(?<=[a-zA-Z])(\.,)(?=[A-Za-z]+)'
+        return [connector_clean, bracket_comma_clean, special_char_clean, comma_period_clean]
\ No newline at end of file
diff --git a/ner_datasets/preprocessing/tokenizers/utils/clinical_regex.py b/ner_datasets/preprocessing/tokenizers/utils/clinical_regex.py
new file mode 100644
index 0000000000000000000000000000000000000000..a7ff1a0476a2fccb53d46a10873d6210d84793ab
--- /dev/null
+++ b/ner_datasets/preprocessing/tokenizers/utils/clinical_regex.py
@@ -0,0 +1,309 @@
+from typing import List
+class ClinicalRegex(object):
+    """
+    This class is used to define the regexes that will be used by the
+    spacy tokenizer rules. Mainly the regexes are used to clean up
+    tokens that have unwanted characters and typos (e.g missing spaces).
+    In the descriptions, when we mention symbol we refer to any character
+    that is not a letter, a digit or an underscore. The spacy tokenizer splits
+    the text by whitespace and applies these rules (along with some default rules)
+    to the individual tokens.
+    """
+    #Patient - 2, 3, 5
+    #Staff - 1, 2
+    #Hosp - 2, 3
+    #Loc - 1, 3
+    @staticmethod
+    def get_word_typo_prefix():
+        """
+        Check if a token contains a typo. What we mean by a typo is when two tokens
+        that should be separate tokens are fused into one token because there
+        is a missing space.
+        Example: JohnMarital Status - John is the name that is fused into the
+        token Marital because of a missing space.
+        The regex checks if we have a sequence of characters followed by another
+        sequence of characters that starts with a capital letter, followed by two or
+        more small letters; we assume this is a typo and split the tokens (two sequences) up.
+        If there is a symbol separating the two sequences, we ease the condition, saying
+        the capital letter can be followed by two or more capital/small letters.
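+        As a rough illustration (not part of the tokenizer itself), assuming
+        the pattern is used directly with the re module:
+            import re
+            re.match(ClinicalRegex.get_word_typo_prefix(), 'JohnMarital').group(0)
+            # -> 'John' (the lookahead leaves 'Marital' for the next rule)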
+        Returns:
+            (str): regex to clean tokens that are fused because of a missing space
+        """
+        #Handles cases 2 of PATIENT
+        #Handles cases 1 & 2 of STAFF
+        #Handles cases 2 & 3 of HOSP
+        #Handles cases 1 & 3 of LOC
+        #'(([a-z]+)|([A-Z]+)|([A-Z][a-z]+))(?=(([-./]*[A-Z][a-z]{2,})|([-./]+[A-Z][a-zA-Z]{2,})))'
+        return '(([a-z]+)|([A-Z]{2,})|([A-Z][a-z]+))(?=(([-./]*[A-Z][a-z]{2,})|([-./]+[A-Z][a-zA-Z]{2,})))'
+
+    @staticmethod
+    def get_word_symbol_digit_prefix() -> str:
+        """
+        If text is followed by one or more symbols and then followed by one or more digits,
+        we make the assumption that the text is a separate token. Spacy will use this regex
+        to extract the text portion as one token and will then move on to
+        process the rest (symbols and digits) based on the defined rules.
+        Example: Yang(4986231) - "Yang" will become a separate token & "(4986231)" will
+        be processed as a new token
+        Returns:
+            (str): regex to clean text followed by symbols followed by digits
+        """
+        #Handles cases 3 & 5 of patient
+        return '([a-zA-Z]+)(?=\W+\d+)'
+
+    @staticmethod
+    def get_multiple_prefix(split_multiple: bool) -> str:
+        """
+        If text is of the format "take it x2 times", this function
+        can be used to treat the entire thing as one token or
+        split it into two separate tokens.
+        Args:
+            split_multiple (bool): whether to treat it as one token or split them up
+        Returns:
+            (str): regex to either keep as one token or split into two
+        """
+        if split_multiple:
+            return '([x])(?=(\d{1,2}$))'
+        else:
+            return '[x]\d{1,2}$'
+
+    @staticmethod
+    def get_pager_prefix():
+        return '([pXxPb])(?=(\d{4,}|\d+[-]\d+))'
+
+    @staticmethod
+    def get_age_word_prefix():
+        return '([MFmf])(?=\d{2,3}(\W+|$))'
+
+    @staticmethod
+    def get_id_prefix():
+        return '(ID|id|Id)(?=\d{3,})'
+
+    @staticmethod
+    def get_word_period_prefix():
+        return '((cf|CF|Cf|dr|DR|Dr|ft|FT|Ft|lt|LT|Lt|mr|MR|Mr|ms|MS|Ms|mt|MT|Mt|mx|MX|Mx|ph|PH|Ph|rd|RD|Rd|st|ST|St|vs|VS|Vs|wm|WM|Wm|[A-Za-z]{1})[.])(?=((\W+|$)))'
+
+    @staticmethod
+    def get_chemical_prefix():
+        #Vitamin B12, T9, or maybe codes like I48.9 - should probably do \d{1,2} to limit arbitrary numbers
+        """
+        There are certain chemicals, vitamins etc that should not be split. They
+        should be kept as a single token - for example the token "B12" in
+        "Vitamin B12". This regex checks if there is a single capital letter
+        followed by some digits (there can be a hyphen in between those digits);
+        this most likely represents a token that should not be split
+        Returns:
+            (str): regex to keep vitamin/chemical names as a single token
+        """
+        #return '((\d)?[A-EG-LN-OQ-WYZ]{1}\d+([.]\d+)?(-\d{1,2})*)(?=(([\(\)\[\]:="])|\W*$))'
+        return '((\d)?[A-EG-LN-OQ-WYZ]{1}\d+([.]\d+)?(-\d+)*)(?=(([\(\)\[\]:="])|\W*$))'
+
+    @staticmethod
+    def get_chemical_prefix_small():
+        #Vitamin B12, T9, or maybe codes like I48.9 - should probably do \d{1,2} to limit arbitrary numbers
+        """
+        There are certain chemicals, vitamins etc that should not be split. They
+        should be kept as a single token - for example the token "B12" in
+        "Vitamin B12".
This regex is the lower-case variant of the one above: it checks if there
+        is a single lower-case letter followed by some digits (there can be a
+        hyphen in between those digits); this most likely represents a token
+        that should not be split
+        Returns:
+            (str): regex to keep vitamin/chemical names as a single token
+        """
+        #return '((\d)?[A-EG-LN-OQ-WYZ]{1}\d+([.]\d+)?(-\d{1,2})*)(?=(([\(\)\[\]:="])|\W*$))'
+        return '((\d)?[a-eg-ln-oq-wyz]{1}\d+([.]\d+)?(-\d+)*)(?=(([\(\)\[\]:="])|\W*$))'
+
+    @staticmethod
+    def get_instrument_prefix():
+        """
+        There are cases where there are tokens like L1-L2-L3; we want to keep these as one
+        single token. This regex checks for capital letters followed by digits,
+        chained together by - or :
+        Returns:
+            (str): regex to keep such instrument/level names as a single token
+        """
+        return '([A-Z]{1,2}\d+(?P<instrument>[-:]+)[A-Z]{1,2}\d+((?P=instrument)[A-Z]{1,2}\d+)*)'
+
+    @staticmethod
+    def get_instrument_prefix_small():
+        """
+        There are cases where there are tokens like l1-l2-l3; we want to keep these as one
+        single token. This regex is the lower-case variant of the one above
+        Returns:
+            (str): regex to keep such instrument/level names as a single token
+        """
+        return '([a-z]{1,2}\d+(?P<instrument_small>[-:]+)[a-z]{1,2}\d+((?P=instrument_small)[a-z]{1,2}\d+)*)'
+
+    #Handles Case 3, 4, 5 of MRN
+    #Handles Case 1, 2, 3 of PHONE
+    #Handles Case 7, 10 of AGE
+    #Handles Case 1 of IDNUM
+    #Handles Case 3, 5 of PATIENT
+    #Handles Case 7 of HOSP
+    #Handles Case 1 of General
+    @staticmethod
+    def get_age_typo_prefix():
+        """
+        There are cases where there is no space between the text and the age.
+        Example: Plan88yo - we want Plan to be a separate token
+        Returns:
+            (str): regex to split the text from the age expression
+        """
+        age_suffix = '(([yY][eE][aA][rR]|[yY][oO]' + \
+                     '|[yY][rR]|[yY]\.[oO]|[yY]/[oO]|[fF]|[mM]|[yY])' + \
+                     '(-)*([o|O][l|L][d|D]|[f|F]|[m|M]|[o|O])?)'
+        return '([a-zA-Z]+)(?=((\d{1,3})' + age_suffix + '$))'
+
+    @staticmethod
+    def get_word_digit_split_prefix():
+        #A word followed by more than 3 digits might not be part of the same token
+        #and could be a typo.
+        #This need not be true - maybe we have an id like BFPI980801 - this will be split into
+        #BFPI 980801 - but it might be okay to split - need to check
+        #([A-Z][a-z]{2,})(?=\d+)
+        return '([A-Z][a-z]{2,})(?=[A-Za-z]*\d+)'
+
+    @staticmethod
+    def get_word_digit_mix_prefix():
+        #Mix of letters and digits - most likely a typo if the
+        #following characters are a capital letter followed by more than
+        #2 small letters
+        #return '([A-Z]+\d+([A-Z]+(?!([a-z]{2,}))))(?=(\W+|([A-Z][a-z]{2,})|[a-z]{3,}))'
+        return '([A-Z]+\d+)(?=(\W+|([A-Z][a-z]{2,})|[a-z]{3,}))'
+
+    @staticmethod
+    def get_word_digit_mix_prefix_small():
+        #Mix of letters and digits - most likely a typo if the
+        #following characters are a capital letter followed by more than
+        #2 small letters
+        return '([a-z]+\d+)(?=(\W+|[A-Z][a-z]{2,}|[A-Z]{3,}))'
+
+    @staticmethod
+    def get_word_id_split_prefix():
+        return '([a-zA-Z]+)(?=(\d+[-./]+(\d+|$)))'
+
+    @staticmethod
+    def get_word_section_prefix():
+        #Fix JOHNID/CC - missing space between the previous section and JOHN
+        return '([A-Za-z]+)(?=(((?P<slash>[/:]+)[A-Za-z]+)((?P=slash)[A-Za-z]+)*\W+\d+))'
+
+    @staticmethod
+    def get_colon_prefix():
+        #Split tokens before and after the colon.
+        #Does not split time - we make sure the token before the colon
+        #starts with a letter.
+        #Splits patterns like <WORD>:<CHARS>, where the part before the colon starts
+        #with a letter and is followed by one or more letters/digits, and the part
+        #after the colon is a combination of letters/digits of length greater than 2.
+        #This won't split time, but assumes that when the colon is present
+        #the entities on either side of the token are different tokens.
+        #A:9 - not split - more likely this makes sense as a single token (could be a chemical)
+        return '([A-Za-z][A-Za-z0-9]+)(?=([:][A-Za-z0-9]{2,}))'
+
+    @staticmethod
+    def get_temperature_prefix(split_temperature):
+        if split_temperature:
+            return '((\d+)|(\d+[.]\d+))(?=(\u00B0([FCK]{1}|$)))'
+        else:
+            return '(((\d+)|(\d+[.]\d+))\u00B0([FCK]{1}|$))|(\u00A9[FCK]{1})'
+
+    @staticmethod
+    def get_percentage_prefix(split_percentage):
+        """
+        If text is of the format "20%", this function
+        can be used to treat the entire thing as one token or
+        split it into two separate tokens.
+        Args:
+            split_percentage (bool): whether to treat it as one token or split them up
+        Returns:
+            (str): regex to either keep as one token or split into two
+        """
+        if split_percentage:
+            return '(((\d+)|(\d+[.]\d+)))(?=(%(\W+|$)))'
+        else:
+            return '(((\d+)|(\d+[.]\d+))%(\W+|$))'
+
+    @staticmethod
+    def get_value_range_prefixes():
+        #The following regexes might not work on .4-.5 - no number before the decimal
+        #point - need to figure this out without breaking anything else
+        value_range_1 = '(\d{1})(?=([-]((\d{1,2}|(\d+)[.](\d+)))([a-zA-Z]+|[\W]*$)))'
+        value_range_2 = '(\d{2})(?=([-]((\d{2,3}|(\d+)[.](\d+)))([a-zA-Z]+|[\W]*$)))'
+        value_range_3 = '(\d{3})(?=([-]((\d{3}|(\d+)[.](\d+)))([a-zA-Z]+|[\W]*$)))'
+        return value_range_1, value_range_2, value_range_3
+
+    @staticmethod
+    def get_year_range_prefix():
+        return '(\d{4})(?=([-](\d{4})([a-zA-Z]+|[\W]*$)))'
+
+    @staticmethod
+    def get_short_digit_id_prefix():
+        #4A, 3C etc
+        return '(\d{1,2}[A-EG-LN-WZ]{1}(?=(\W+|$)))'
+
+    #Handles Case 1, 2 of MRN
+    #Handles Case 1, 2, 3, 4, 5, 6, 7, 8, 9, 12, 13, 14, 15, 16, 17, 18, 19 of AGE
+    #Handles Case 2, 3, 5 of IDNUM
+    #Handles Case 1 of HOSP
+    @staticmethod
+    def get_digit_symbol_word_prefix():
+        return '((\d+)|(\d+[.]\d+))(?=\W+[a-zA-Z]+)'
+
+    @staticmethod
+    def get_digit_age_split_prefix():
+        age_suffix = '(([yY][eE][aA][rR]|[yY][oO]' + \
+                     '|[yY][rR]|[yY]\.[oO]|[yY]/[oO]|[fF]|[mM]|[yY])' + \
+                     '(-)*([o|O][l|L][d|D]|[f|F]|[m|M]|[o|O])?)'
+        return '((\d{1,3}))(?=(' + age_suffix + '\W*$))'
+
+    @staticmethod
+    def get_digit_word_short_prefix():
+        return '((\d+)|(\d+[.]\d+))([a-z]{1,2}|[A-Z]{1,2})(?=(\W*$))'
+
+    @staticmethod
+    def get_digit_word_typo_prefix():
+        return '((\d+)|(\d+[.]\d+))(?=[a-zA-Z]{1}[a-zA-Z\W]+)'
+
+    @staticmethod
+    def get_prefixes(split_multiple, split_temperature, split_percentage):
+        word_typo_prefix = ClinicalRegex.get_word_typo_prefix()
+        word_symbol_digit_prefix = ClinicalRegex.get_word_symbol_digit_prefix()
+        pager_prefix = ClinicalRegex.get_pager_prefix()
+        age_word_prefix = ClinicalRegex.get_age_word_prefix()
+        word_period_prefix = ClinicalRegex.get_word_period_prefix()
+        id_prefix = ClinicalRegex.get_id_prefix()
+        multiple_prefix = ClinicalRegex.get_multiple_prefix(split_multiple)
+        chemical_prefix = ClinicalRegex.get_chemical_prefix()
+        chemical_prefix_small = ClinicalRegex.get_chemical_prefix_small()
+        instrument_prefix = ClinicalRegex.get_instrument_prefix()
+        instrument_prefix_small = ClinicalRegex.get_instrument_prefix_small()
+        age_typo_prefix = ClinicalRegex.get_age_typo_prefix()
+        word_digit_split_prefix = ClinicalRegex.get_word_digit_split_prefix()
+        word_digit_mix_prefix = ClinicalRegex.get_word_digit_mix_prefix()
ClinicalRegex.get_word_digit_mix_prefix() + word_digit_mix_prefix_small = ClinicalRegex.get_word_digit_mix_prefix_small() + word_id_split_prefix = ClinicalRegex.get_word_id_split_prefix() + word_section_prefix = ClinicalRegex.get_word_section_prefix() + colon_prefix = ClinicalRegex.get_colon_prefix() + temperature_prefix = ClinicalRegex.get_temperature_prefix(split_temperature) + percentage_prefix = ClinicalRegex.get_percentage_prefix(split_percentage) + value_range_1, value_range_2, value_range_3 = ClinicalRegex.get_value_range_prefixes() + year_range_prefix = ClinicalRegex.get_year_range_prefix() + short_digit_id_prefix = ClinicalRegex.get_short_digit_id_prefix() + digit_symbol_word_prefix = ClinicalRegex.get_digit_symbol_word_prefix() + digit_age_split_prefix = ClinicalRegex.get_digit_age_split_prefix() + digit_word_short_prefix = ClinicalRegex.get_digit_word_short_prefix() + digit_word_typo_prefix = ClinicalRegex.get_digit_word_typo_prefix() + + return [word_typo_prefix, word_symbol_digit_prefix, pager_prefix, age_word_prefix,\ + word_period_prefix, id_prefix, multiple_prefix, chemical_prefix, chemical_prefix_small,\ + instrument_prefix, instrument_prefix_small, age_typo_prefix, word_digit_split_prefix,\ + word_id_split_prefix, word_digit_mix_prefix, word_digit_mix_prefix_small, \ + word_section_prefix, colon_prefix, temperature_prefix,\ + percentage_prefix, value_range_1, value_range_2, value_range_3, year_range_prefix,\ + short_digit_id_prefix, digit_symbol_word_prefix, digit_age_split_prefix,\ + digit_word_short_prefix, digit_word_typo_prefix] + + @staticmethod + def get_infixes(): + digit_infix = '(\d+(?P[-:]+)\d+((?P=sep)\d+)*)' + return [digit_infix, ] + \ No newline at end of file diff --git a/ner_datasets/preprocessing/tokenizers/utils/date_regex.py b/ner_datasets/preprocessing/tokenizers/utils/date_regex.py new file mode 100644 index 0000000000000000000000000000000000000000..c8f6929a386b2160695ee183a9eb2e509bb69825 --- /dev/null +++ b/ner_datasets/preprocessing/tokenizers/utils/date_regex.py @@ -0,0 +1,104 @@ +class DateRegex(object): + + @staticmethod + def __get_day_attributes(): + # day of the month with optional suffix, such as 7th, 22nd + dd = '(([0-2]?[0-9]|3[01])(\s*)([sS][tT]|[nN][dD]|[rR][dD]|[tT][hH])?)' + # two-digit numeric day of the month + DD = '(0[0-9]|[1-2][0-9]|3[01])' + + return dd, DD + + @staticmethod + def __get_month_attributes(): + + m = \ + '([jJ][aA][nN]([uU][aA][rR][yY])?|'+\ + '[fF][eE][bB]([rR][uU][aA][rR][yY])?|'+\ + '[mM][aA][rR]([cC][hH])?|'+\ + '[aA][pP][rR]([iI][lL])?|'+\ + '[mM][aA][yY]|'+\ + '[jJ][uU][nN]([eE])?|'+\ + '[jJ][uU][lL]([yY])?|'+\ + '[aA][uU][gG]([uU][sS][tT])?|'+\ + '[sS][eE][pP]([tT][eE][mM][bB][eE][rR])?|'+\ + '[oO][cC][tT]([oO][bB][eE][rR])?|'+\ + '[nN][oO][vV]([eE][mM][bB][eE][rR])?|'+\ + '[dD][eE][cC]([eE][mM][bB][eE][rR])?)' + M = m + + # numeric month + mm = '(0?[0-9]|1[0-2]|' + m + ')' + + # two digit month + MM = '(0[0-9]|1[0-2]|' + m + ')' + + return m, M, mm, MM + + @staticmethod + def __get_year_attributes(): + + # two or four digit year + y = '([0-9]{4}|[0-9]{2})' + + # two digit year + yy = '([0-9]{2})' + + # four digit year + YY = '([0-9]{4})' + + return y, yy, YY + + @staticmethod + def __get_sep_attributes(): + + date_sep = '[-./]' + date_sep_optional = '[-./]*' + date_sep_no_full = '[-/]' + + return date_sep, date_sep_optional, date_sep_no_full + + #def get_week_attributes(): + # w = \ + # '([mM][oO][nN]([dD][aA][yY])?|'+\ + # '[tT][uU][eE]([sS][dD][aA][yY])?|'+\ + # '[wW][eE][dD]([nN][eE][sS][dD][aA][yY])?|'+\ 
+ # '[tT][hH][uU][gG]([uU][sS][tT])?|'+\ + # '[sS][eE][pP]([tT][eE][mM][bB][eE][rR])?|'+\ + # '[oO][cC][tT]([oO][bB][eE][rR])?|'+\ + # '[nN][oO][vV]([eE][mM][bB][eE][rR])?|'+\ + # '[dD][eE][cC]([eE][mM][bB][eE][rR])?)' + + @staticmethod + def get_infixes(): + + dd, DD = DateRegex.__get_day_attributes() + m, M, mm, MM = DateRegex.__get_month_attributes() + y, yy, YY = DateRegex.__get_year_attributes() + date_sep, date_sep_optional, date_sep_no_full = DateRegex.__get_sep_attributes() + + date_1 = y + '/' + mm + '/' + dd + '(?!([/]+|\d+))' + date_2 = y + '/' + dd + '/' + mm + '(?!([/]+|\d+))' + date_3 = dd + '/' + mm + '/' + y + '(?!([/]+|\d+))' + date_4 = mm + '/' + dd + '/' + y + '(?!([/]+|\d+))' + #Do I make this optional (date_sep_optional) - need to check + date_5 = y + date_sep + m + date_sep + dd + '(?!\d)' + date_6 = y + date_sep + dd + date_sep + m + date_7 = dd + date_sep + m + date_sep + y + date_8 = m + date_sep + dd + date_sep + y + date_9 = y + date_sep + m + date_10 = m + date_sep + y + date_11 = dd + date_sep + m + date_12 = m + date_sep + dd + date_13 = '(? So now AGE Span is 75yo instead of 75. This script essentially changes + the annotated spans to match the tokens. In an ideal case we wouldn't need this script + but since medical notes have many typos, this script becomes necessary to deal with + issues and changes that arise from different tokenizers. + Also sort the spans and convert the start and end keys of the spans to integers + """ + + def __init__( + self, + sentencizer: str, + tokenizer: str, + ner_priorities: Mapping[str, int], + verbose: bool = True + ) -> NoReturn: + """ + Initialize the sentencizer and tokenizer + Args: + sentencizer (str): The sentencizer to use for splitting text into sentences + tokenizer (str): The tokenizer to use for splitting text into tokens + ner_priorities (Mapping[str, int]): The priority when choosing which duplicates to remove. + Mapping that represents a priority for each NER type + verbose (bool): To print out warnings etc + """ + self._sentencizer = PreprocessingLoader.get_sentencizer(sentencizer) + self._tokenizer = PreprocessingLoader.get_tokenizer(tokenizer) + self._ner_priorities = ner_priorities + self._verbose = verbose + + def __get_token_positions(self, text: str) -> Tuple[Dict[int, int], Dict[int, int]]: + """ + Get the start and end positions of all the tokens in the note. + Args: + text (str): The text present in the note + Returns: + token_start_positions (Mapping[int, int]): The start positions of all the tokens in the note + token_end_positions (Mapping[int, int]): The end positions of all the tokens in the note + """ + token_start_positions = dict() + token_end_positions = dict() + for sentence in self._sentencizer.get_sentences(text): + offset = sentence['start'] + for token in self._tokenizer.get_tokens(sentence['text']): + start = token['start'] + offset + end = token['end'] + offset + token_start_positions[start] = 1 + token_end_positions[end] = 1 + return token_start_positions, token_end_positions + + def get_duplicates( + self, + spans: List[Dict[str, Union[str, int]]], + ) -> List[int]: + """ + Return the indexes where there are duplicate/overlapping spans. A duplicate or + span is one where the same token can have two labels. + E.g: + Token: BWH^Bruce + This is a single token where BWH is the hospital label and Bruce is the Patient label + The fix_alignment function assigns this entre token the hospital label but it also + assigns this entire token the patient label. 
Since we have two labels for the same
+        token, we need to remove one of them.
+        We assign this entire token one label - either the hospital label or the patient label
+        In this case we assign patient because of higher priority. So now we need to remove
+        the hospital label from the dataset (since it is essentially a duplicate label). This
+        script handles this case.
+        There are cases when two different labels match the same token partially
+        E.g:
+        Text: JT/781-815-9090
+        Spans: JT - hospital, 781-815-9090 - Phone
+        Tokens: (JT/781) & (- 815 - 9090)
+        As you can see the token JT/781 will be assigned the hospital label in the fix_alignment function
+        but 781-815-9090 is also phone and the 781 portion is overlapped, and we need to resolve this.
+        In this script, we resolve it by treating JT/781 as one span (hospital) and
+        -815-9090 as another span (phone).
+        Args:
+            spans ([List[Dict[str, Union[str, int]]]): The NER spans in the note
+        Returns:
+            remove_spans (Sequence[int]): A list of indexes of the spans to remove
+        """
+        remove_spans = list()
+        prev_start = -1
+        prev_end = -1
+        prev_label = None
+        prev_index = None
+        spans.sort(key=lambda _span: (_span['start'], _span['end']))
+        for index, span in enumerate(spans):
+            current_start = span['start']
+            current_end = span['end']
+            current_label = span['label']
+            if type(current_start) != int or type(current_end) != int:
+                raise ValueError('The start and end keys of the span must be of type int')
+            # Check if the current span matches another span
+            # that is if this span covers the same tokens as the
+            # previous spans (but has a different label)
+            # Based on the priority, treat the span with the low
+            # priority label as a duplicate label and add it to the
+            # list of spans that need to be removed
+            if current_start == prev_start and current_end == prev_end:
+                if self._ner_priorities[current_label] > self._ner_priorities[prev_label]:
+                    # Store index of the previous span if it has lower priority
+                    remove_spans.append(prev_index)
+                    # Reset span details
+                    prev_start = current_start
+                    prev_end = current_end
+                    prev_index = index
+                    prev_label = current_label
+                    if self._verbose:
+                        print('DUPLICATE: ', span)
+                        print('REMOVED: ', spans[remove_spans[-1]])
+                elif self._ner_priorities[current_label] <= self._ner_priorities[prev_label]:
+                    # Store current index of span if it has lower priority
+                    remove_spans.append(index)
+                    if self._verbose:
+                        print('DUPLICATE: ', spans[prev_index])
+                        print('REMOVED: ', spans[remove_spans[-1]])
+            # Check for overlapping span
+            elif current_start < prev_end:
+                # If the current span end matches the overlapping span end
+                # Remove the current span, since it is smaller
+                if current_end <= prev_end:
+                    remove_spans.append(index)
+                    if self._verbose:
+                        print('DUPLICATE: ', spans[prev_index])
+                        print('REMOVED: ', spans[remove_spans[-1]])
+                # If the current end is greater than the prev_end
+                # then we split it into two spans. We treat the previous span
+                # as one span and the end of the previous span to the end of the current span
+                # as another span.
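+                # e.g. prev 'JT/781' (0, 6) HOSP vs current '781-815-9090' (3, 15) PHONE:
+                # the phone span is trimmed to '-815-9090' (6, 15) below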
+                elif current_end > prev_end:
+                    # Create the new span - start=previous_span_end, end=current_span_end
+                    overlap_length = spans[prev_index]['end'] - current_start
+                    new_text = span['text'][overlap_length:]
+                    # Remove extra spaces that may arise during this span separation
+                    new_text = re.sub('^(\s+)', '', new_text, flags=re.DOTALL)
+                    span['start'] = current_end - len(new_text)
+                    span['text'] = new_text
+                    if self._verbose:
+                        print('OVERLAP: ', spans[prev_index])
+                        print('UPDATED: ', span)
+                    # Reset span details
+                    prev_start = current_start
+                    prev_end = current_end
+                    prev_label = current_label
+                    prev_index = index
+            # Reset span details
+            else:
+                prev_start = current_start
+                prev_end = current_end
+                prev_label = current_label
+                prev_index = index
+        return remove_spans
+
+    def fix_alignment(
+        self,
+        text: str,
+        spans: Sequence[Dict[str, Union[str, int]]]
+    ) -> Iterable[Dict[str, Union[str, int]]]:
+        """
+        Align the spans and tokens. When the tokens and spans don't align, we change the
+        start and end positions of the spans so that they align with the tokens. This is
+        needed when a different tokenizer is used and the spans which are defined against
+        a different tokenizer don't line up with the new tokenizer. Also remove spaces present
+        at the start or end of the span.
+        E.g:
+        Token: BWH^Bruce
+        This is a single token where BWH is the hospital label and Bruce is the Patient label
+        The fix_alignment function assigns this entire token the hospital label but it also
+        assigns this entire token the patient label. This function basically expands the span
+        so that it matches the start and end positions of some token. By doing this it may create
+        overlapping and duplicate spans. As you can see it expands the patient label to match the
+        start of the token and it expands the hospital label to match the end of the token.
+        Args:
+            text (str): The text present in the note
+            spans ([Sequence[Dict[str, Union[str, int]]]): The NER spans in the note
+        Returns:
+            (Iterable[Dict[str, Union[str, int]]]): Iterable through the modified spans
+        """
+        # Get token start and end positions so that we can check if a span
+        # coincides with the start and end position of some token.
+        token_start_positions, token_end_positions = self.__get_token_positions(text)
+        for span in spans:
+            start = span['start']
+            end = span['end']
+            if type(start) != int or type(end) != int:
+                raise ValueError('The start and end keys of the span must be of type int')
+            if re.search('^\s', text[start:end]):
+                if self._verbose:
+                    print('WARNING - space present in the start of the span')
+                start = start + 1
+            if re.search('(\s+)$', text[start:end], flags=re.DOTALL):
+                new_text = re.sub('(\s+)$', '', text[start:end], flags=re.DOTALL)
+                end = start + len(new_text)
+            # When a span does not coincide with the start and end position of some token
+            # it means there will be an error when building the ner dataset, we try and avoid
+            # that error by updating the spans themselves, that is we expand the start/end positions
+            # of the spans so that they are aligned with the tokens.
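+            # e.g. an AGE span covering only '75' inside the token '75yo' is widened
+            # here until it lines up with the token boundaries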
+ while token_start_positions.get(start, False) is False: + start -= 1 + while token_end_positions.get(end, False) is False: + end += 1 + # Print what the old span was and what the new expanded span will look like + if self._verbose and (int(span['start']) != start or int(span['end']) != end): + print('OLD SPAN: ', text[int(span['start']):int(span['end'])]) + print('NEW SPAN: ', text[start:end]) + # Update the span with its new start and end positions + span['start'] = start + span['end'] = end + span['text'] = text[start:end] + yield span + + def fix_note( + self, + text: str, + spans: Sequence[Dict[str, Union[str, int]]], + ) -> Iterable[Dict[str, Union[str, int]]]: + """ + This function changes the span_start and span_end + so that the span_start will coincide with some token_start and the span_end + will coincide with some token_end and also removes duplicate/overlapping spans + that may arise when we change the span start and end positions. The resulting + spans from this function will always coincide with some token start and token + end, and hence will not have any token and span mismatch errors when building the + NER dataset. For more details and examples check the documentation of the + fix_alignment and get_duplicates functions. + Args: + text (str): The text present in the note + spans ([Sequence[Mapping[str, Union[str, int]]]): The NER spans in the note + Returns: + (Iterable[Mapping[str, Union[str, int]]]): Iterable through the fixed spans + """ + # Fix span position alignment + spans = [span for span in self.fix_alignment(text=text, spans=spans)] + # Check for duplicate/overlapping spans + remove_spans = self.get_duplicates(spans=spans) + for index, span in enumerate(spans): + # Remove the duplicate/overlapping spans + if index not in remove_spans: + yield span + + def fix( + self, + input_file: str, + text_key: str = 'text', + spans_key: str = 'spans' + ) -> Iterable[Dict[str, Union[str, Dict[str, str], List[Dict[str, str]]]]]: + """ + This function changes the span_start and span_end + so that the span_start will coincide with some token_start and the span_end + will coincide with some token_end and also removes duplicate/overlapping spans + that may arise when we change the span start and end positions. The resulting + spans from this function will always coincide with some token start and token + end, and hence will not have any token and span mismatch errors when building the + NER dataset. For more details and examples check the documentation of the + fix_alignment and get_duplicates functions. Fix spans that arise due to bad typos, + which are not fixed during tokenization. 
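+        E.g: a PATIENT span covering only 'Bruce' within the single token 'BWH^Bruce' is
+        expanded to cover the whole token, and the resulting duplicate label is dropped.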
This essentially updates the spans so that
+        they line up with the start and end positions of tokens - so that there is no error
+        when we assign labels to tokens based on these spans
+        Args:
+            input_file (str): The file that contains the notes that we want to fix the token issues in
+            text_key (str): the key where the note & token text is present in the json object
+            spans_key (str): The key where the note spans are present in the json object
+        Returns:
+            (Iterable[Dict[str, Union[str, Dict[str, str], List[Dict[str, str]]]]]): Iterable through the fixed notes
+        """
+        for line in open(input_file, 'r'):
+            note = json.loads(line)
+            note[spans_key] = [span for span in self.fix_note(text=note[text_key], spans=note[spans_key])]
+            yield note
+
+
+def main():
+    # The following code sets up the arguments to be passed via CLI or via a JSON file
+    cli_parser = ArgumentParser(
+        description='configuration arguments provided at run time from the CLI',
+        formatter_class=ArgumentDefaultsHelpFormatter
+    )
+    cli_parser.add_argument(
+        '--input_file',
+        type=str,
+        required=True,
+        help='the jsonl file that contains the notes'
+    )
+    cli_parser.add_argument(
+        '--sentencizer',
+        type=str,
+        required=True,
+        help='the sentencizer to use for splitting notes into sentences'
+    )
+    cli_parser.add_argument(
+        '--tokenizer',
+        type=str,
+        required=True,
+        help='the tokenizer to use for splitting text into tokens'
+    )
+    cli_parser.add_argument(
+        '--abbreviations_file',
+        type=str,
+        default=None,
+        help='file that will be used by clinical tokenizer to handle abbreviations'
+    )
+    cli_parser.add_argument(
+        '--ner_types',
+        nargs="+",
+        required=True,
+        help='the NER types'
+    )
+    cli_parser.add_argument(
+        '--ner_priorities',
+        nargs="+",
+        required=True,
+        help='the priorities for the NER types - the priority when choosing which duplicates to remove'
+    )
+    cli_parser.add_argument(
+        '--text_key',
+        type=str,
+        default='text',
+        help='the key where the note & token text is present in the json object'
+    )
+    cli_parser.add_argument(
+        '--spans_key',
+        type=str,
+        default='spans',
+        help='the key where the note spans are present in the json object'
+    )
+    cli_parser.add_argument(
+        '--output_file',
+        type=str,
+        required=True,
+        help='the output json file that will contain the new fixed spans'
+    )
+    args = cli_parser.parse_args()
+    # Mapping that represents a priority for each PHI type
+    # For example, the PATIENT type will have a higher priority as
+    # compared to STAFF.
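+    # e.g. --ner_types PATIENT STAFF --ner_priorities 2 1 yields {'PATIENT': 2, 'STAFF': 1}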
+    if len(args.ner_types) == len(args.ner_priorities):
+        # The CLI reads the priorities as strings - cast them to int so the
+        # priority comparisons in get_duplicates are numeric
+        ner_priorities = {ner_type: int(priority) for ner_type, priority in zip(args.ner_types, args.ner_priorities)}
+    else:
+        raise ValueError('Length of ner_types and ner_priorities must be the same')
+    span_fixer = SpanFixer(
+        tokenizer=args.tokenizer,
+        sentencizer=args.sentencizer,
+        ner_priorities=ner_priorities
+    )
+    with open(args.output_file, 'w') as file:
+        for note in span_fixer.fix(
+                input_file=args.input_file,
+                text_key=args.text_key,
+                spans_key=args.spans_key
+        ):
+            file.write(json.dumps(note) + '\n')
+
+
+if __name__ == '__main__':
+    main()
diff --git a/ner_datasets/span_validation.py b/ner_datasets/span_validation.py
new file mode 100644
index 0000000000000000000000000000000000000000..760b404b8be5d6ae78173e07a752b5dad5be08b1
--- /dev/null
+++ b/ner_datasets/span_validation.py
@@ -0,0 +1,91 @@
+import json
+import random
+from argparse import ArgumentParser
+from typing import Union, NoReturn, Iterable, Dict, List
+
+random.seed(41)
+
+
+class SpanValidation(object):
+    """
+    This class is used to build a mapping between the note id
+    and the annotated spans in that note. This will be used during the
+    evaluation of the models. This is required to perform span level
+    evaluation.
+    """
+    @staticmethod
+    def get_spans(
+        input_file: str,
+        metadata_key: str = 'meta',
+        note_id_key: str = 'note_id',
+        spans_key: str = 'spans'
+    ):
+        """
+        Get a mapping between the note id
+        and the annotated spans in that note. This will mainly be used during the
+        evaluation of the models.
+        Args:
+            input_file (str): The input file
+            metadata_key (str): The key where the note metadata is present
+            note_id_key (str): The key where the note id is present
+            spans_key (str): The key that contains the annotated spans for a note dictionary
+        Returns:
+            (Iterable[Dict[str, Union[str, List[Dict[str, str]]]]]): An iterable that iterates through each note
+                                                                     and contains the note id and annotated spans
+                                                                     for that note
+        """
+        # Read the input files (data source)
+        for line in open(input_file, 'r'):
+            note = json.loads(line)
+            note_id = note[metadata_key][note_id_key]
+            # Store the note_id and the annotated spans
+            note[spans_key].sort(key=lambda _span: (_span['start'], _span['end']))
+            yield {'note_id': note_id, 'note_spans': note[spans_key]}
+
+
+def main() -> NoReturn:
+    cli_parser = ArgumentParser(description='configuration arguments provided at run time from the CLI')
+    cli_parser.add_argument(
+        '--input_file',
+        type=str,
+        required=True,
+        help='the jsonl file that contains the notes'
+    )
+    cli_parser.add_argument(
+        '--metadata_key',
+        type=str,
+        default='meta',
+        help='the key where the note metadata is present in the json object'
+    )
+    cli_parser.add_argument(
+        '--note_id_key',
+        type=str,
+        default='note_id',
+        help='the key where the note id is present in the json object'
+    )
+    cli_parser.add_argument(
+        '--spans_key',
+        type=str,
+        default='spans',
+        help='the key where the annotated spans for the notes are present in the json object'
+    )
+    cli_parser.add_argument(
+        '--output_file',
+        type=str,
+        required=True,
+        help='the file where the note id and the corresponding spans for that note are to be saved'
+    )
+    args = cli_parser.parse_args()
+
+    # Write the dataset to the output file
+    with open(args.output_file, 'w') as file:
+        for span_info in SpanValidation.get_spans(
+                input_file=args.input_file,
+                metadata_key=args.metadata_key,
+                note_id_key=args.note_id_key,
+                spans_key=args.spans_key):
+            file.write(json.dumps(span_info) + '\n')
+ + +if __name__ == "__main__": + main() diff --git a/sequence_tagging/.DS_Store b/sequence_tagging/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..c5339b6afa3f0b805f6ede10fce2bbafab01138d Binary files /dev/null and b/sequence_tagging/.DS_Store differ diff --git a/sequence_tagging/__init__.py b/sequence_tagging/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a359b711ce2bcfee79eee67bdffe381aa1568b20 --- /dev/null +++ b/sequence_tagging/__init__.py @@ -0,0 +1,2 @@ +from .sequence_tagger import SequenceTagger +__all__ = ["SequenceTagger"] diff --git a/sequence_tagging/__pycache__/__init__.cpython-37.pyc b/sequence_tagging/__pycache__/__init__.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ef64b9f986e99db20b208542f7fbcad350aeaea3 Binary files /dev/null and b/sequence_tagging/__pycache__/__init__.cpython-37.pyc differ diff --git a/sequence_tagging/__pycache__/sequence_tagger.cpython-37.pyc b/sequence_tagging/__pycache__/sequence_tagger.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..14848a957f5eb9d8465abc2ac7be4de0151d1354 Binary files /dev/null and b/sequence_tagging/__pycache__/sequence_tagger.cpython-37.pyc differ diff --git a/sequence_tagging/arguments/__init__.py b/sequence_tagging/arguments/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..58f06da514b8380b3688c6775a8c243917c861f9 --- /dev/null +++ b/sequence_tagging/arguments/__init__.py @@ -0,0 +1,8 @@ +from .model_arguments import ModelArguments +from .evaluation_arguments import EvaluationArguments +from .data_training_arguments import DataTrainingArguments +__all__ = [ + "ModelArguments", + "DataTrainingArguments", + "EvaluationArguments", +] diff --git a/sequence_tagging/arguments/data_training_arguments.py b/sequence_tagging/arguments/data_training_arguments.py new file mode 100644 index 0000000000000000000000000000000000000000..e7d34f1e3a1b7dc95134d211f3037506c767a267 --- /dev/null +++ b/sequence_tagging/arguments/data_training_arguments.py @@ -0,0 +1,115 @@ +from typing import Optional +from dataclasses import dataclass, field + +@dataclass +class DataTrainingArguments: + """ + Arguments pertaining to what data we are going to input our model for training and eval. 
+ """ + task_name: Optional[str] = field( + default="ner", + metadata={"help": "The name of the task (ner, pos...)."} + ) + notation: str = field( + default="BIO", + metadata={"help": "NER notation e.g BIO"}, + ) + ner_types: Optional[str] = field( + default=None, + metadata={"help": "Pass a list of NER types"}, + ) + train_file: Optional[str] = field( + default=None, + metadata={"help": "The input training data file (a csv or JSON file)."} + ) + validation_file: Optional[str] = field( + default=None, + metadata={"help": "An optional input evaluation data file to evaluate on (a csv or JSON file)."}, + ) + test_file: Optional[str] = field( + default=None, + metadata={"help": "An optional input test data file to predict on (a csv or JSON file)."}, + ) + output_predictions_file: Optional[str] = field( + default=None, + metadata={"help": "A location where to write the output of the test data"}, + ) + text_column_name: Optional[str] = field( + default='tokens', + metadata={"help": "The column name of text to input in the file (a csv or JSON file)."} + ) + label_column_name: Optional[str] = field( + default='labels', + metadata={"help": "The column name of label to input in the file (a csv or JSON file)."} + ) + overwrite_cache: bool = field( + default=False, + metadata={"help": "Overwrite the cached training and evaluation sets"} + ) + preprocessing_num_workers: Optional[int] = field( + default=None, + metadata={"help": "The number of processes to use for the preprocessing."}, + ) + pad_to_max_length: bool = field( + default=False, + metadata={ + "help": "Whether to pad all samples to model maximum sentence length. " + "If False, will pad the samples dynamically when batching to the maximum length in the batch. More " + "efficient on GPU but very bad for TPU." + }, + ) + truncation: bool = field( + default=True, + metadata={ + "help": "Activates and controls truncation" + }, + ) + max_length: int = field( + default=512, + metadata={ + "help": "Controls the maximum length to use by one of the truncation/padding parameters." + }, + ) + do_lower_case: bool = field( + default=False, + metadata={ + "help": "Whether to lowercase the text" + }, + ) + max_train_samples: Optional[int] = field( + default=None, + metadata={ + "help": "For debugging purposes or quicker training, truncate the number of training examples to this " + "value if set." + }, + ) + max_eval_samples: Optional[int] = field( + default=None, + metadata={ + "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this " + "value if set." + }, + ) + max_predict_samples: Optional[int] = field( + default=None, + metadata={ + "help": "For debugging purposes or quicker training, truncate the number of prediction examples to this " + "value if set." + }, + ) + label_all_tokens: bool = field( + default=False, + metadata={ + "help": "Whether to put the label for one word on all tokens of generated by that word or just on the " + "one (in which case the other tokens will have a padding index)." + }, + ) + return_entity_level_metrics: bool = field( + default=True, + metadata={"help": "Whether to return all the entity levels during evaluation or just the overall ones."}, + ) + token_ignore_label: str = field( + default='NA', + metadata={"help": "The label that indicates where the tokens will be ignored in loss computation. 
Used for " + "indicating context tokens to the model"} + ) \ No newline at end of file diff --git a/sequence_tagging/arguments/evaluation_arguments.py b/sequence_tagging/arguments/evaluation_arguments.py new file mode 100644 index 0000000000000000000000000000000000000000..60cd200136102195b1058bb0a46517f78b96db0f --- /dev/null +++ b/sequence_tagging/arguments/evaluation_arguments.py @@ -0,0 +1,26 @@ +from typing import Optional +from dataclasses import dataclass, field + +@dataclass +class EvaluationArguments: + """ + Arguments pertaining to the evaluation process. + """ + model_eval_script: Optional[str] = field( + default=None, + metadata={"help": "The script that is used for evaluation"}, + ) + evaluation_mode: Optional[str] = field( + default=None, + metadata={"help": "Strict or default mode for sequence evaluation"}, + ) + validation_spans_file: Optional[str] = field( + default=None, + metadata={"help": "A span evaluation data file to evaluate on span level (json file). This will contain a " + "mapping between the note_ids and note spans"}, + ) + ner_type_maps: Optional[str] = field( + default=None, + metadata={"help": "List that contains the mappings between the original NER types to another set of NER " + "types. Used mainly for evaluation. to map ner token labels to another set of ner token"}, + ) \ No newline at end of file diff --git a/sequence_tagging/arguments/model_arguments.py b/sequence_tagging/arguments/model_arguments.py new file mode 100644 index 0000000000000000000000000000000000000000..39d1c44db25ef450361d41e872bbfa6f9809a234 --- /dev/null +++ b/sequence_tagging/arguments/model_arguments.py @@ -0,0 +1,43 @@ +from typing import Optional +from dataclasses import dataclass, field + + +@dataclass +class ModelArguments: + """ + Arguments pertaining to which model/config/tokenizer we are going to fine-tune from. + """ + model_name_or_path: str = field( + metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"} + ) + config_name: Optional[str] = field( + default=None, + metadata={"help": "Pretrained config name or path if not the same as model_name"} + ) + tokenizer_name: Optional[str] = field( + default=None, + metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} + ) + cache_dir: Optional[str] = field( + default=None, + metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"}, + ) + model_revision: str = field( + default="main", + metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}, + ) + use_auth_token: bool = field( + default=False, + metadata={ + "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script " + "with private models)." 
+ }, + ) + post_process: str = field( + default='argmax', + metadata={"help": "What post processing to use on the model logits"}, + ) + threshold: Optional[float] = field( + default=None, + metadata={"help": "Threshold cutoff for softmax"}, + ) diff --git a/sequence_tagging/dataset_builder/__init__.py b/sequence_tagging/dataset_builder/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f7260dbc831f7d79b0b85a4ce2ff386597e672d5 --- /dev/null +++ b/sequence_tagging/dataset_builder/__init__.py @@ -0,0 +1,5 @@ +from .ner_labels import NERLabels +from .ner_dataset import NERDataset +from .label_mapper import LabelMapper +from .dataset_tokenizer import DatasetTokenizer +__all__=["NERLabels", "NERDataset", "LabelMapper", "DatasetTokenizer"] \ No newline at end of file diff --git a/sequence_tagging/dataset_builder/dataset_tokenizer.py b/sequence_tagging/dataset_builder/dataset_tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..5fffcb6f929a153ac3b4ddb46d7108950845ca7e --- /dev/null +++ b/sequence_tagging/dataset_builder/dataset_tokenizer.py @@ -0,0 +1,178 @@ +from typing import Mapping, Sequence, List, Union, Optional, NoReturn +from datasets import Dataset +from transformers import PreTrainedTokenizerFast, PreTrainedTokenizer + + +class DatasetTokenizer(object): + """ + The main goal of this class is to solve the problem described below. + Most of the comments have been copied from the huggingface webpage. + What this class does is initialize a tokenizer with the desired parameters + and then tokenize our dataset and align the tokens with the labels + while keeping in mind the problem & solution described below. We can use this + function to train and for predictions - we just assume the predictions dataset + will have a label column filled with some values (so this code can be re-used). + Now we arrive at a common obstacle with using pre-trained models for + token-level classification: many of the tokens in the dataset may not + be in the tokenizer vocabulary. Bert and many models like it use a method + called WordPiece Tokenization, meaning that single words are split into multiple + tokens such that each token is likely to be in the vocabulary. For example, + the tokenizer would split the date (token) 2080 into the tokens ['208', '##0']. + This is a problem for us because we have exactly one tag per token (2080 -> B-DATE). + If the tokenizer splits a token into multiple sub-tokens, then we will end up with + a mismatch between our tokens and our labels (208, 0) - two tokens but one label (B-DATE). + One way to handle this is to only train on the tag labels for the first subtoken of a + split token. We can do this in huggingface Transformers by setting the labels + we wish to ignore to -100. In the example above, if the label for 2080 is B-DATE + and say the id (from the label to id mapping) for B-DATE is 3, we would set the labels + of ['208', '##0'] to [3, -100]. This tells the model to ignore the tokens labelled with + -100 while updating the weights etc. + """ + + def __init__( + self, + tokenizer: Union[PreTrainedTokenizerFast, PreTrainedTokenizer], + token_column: str, + label_column: str, + label_to_id: Mapping[str, int], + b_to_i_label: Sequence[int], + padding: Union[bool, str], + truncation: Union[bool, str], + is_split_into_words: bool, + max_length: Optional[int], + label_all_tokens: bool, + token_ignore_label: Optional[str] + ) -> NoReturn: + """ + Set the tokenizer we are using to subword tokenizer our dataset. 
The name of the
+        column that contains the pre-split tokens, the name of the column that contains
+        the labels for each token, label to id mapping.
+        Set the padding strategy of the input. Set whether to truncate the input tokens.
+        Indicate whether the input is pre-split into tokens. Set the max length of the
+        input tokens (post subword tokenization). This will be used in conjunction with truncation.
+        Set whether we want to label even the sub tokens
+        In the description above we say for 2080 (B-DATE) - [208, ##0]
+        We do [3, -100] - which says assume the label of token 2080 is the one
+        predicted for 208; or we can label both sub tokens,
+        in which case it would be [3, 3] - so we would label 208 as DATE
+        and ##0 as DATE - then we would have to figure out how to merge these
+        labels etc
+        Args:
+            tokenizer (Union[PreTrainedTokenizerFast, PreTrainedTokenizer]): Tokenizer from huggingface
+            token_column (str): The column that contains the tokens in the dataset
+            label_column (str): The column that contains the labels in the dataset
+            label_to_id (Mapping[str, int]): The mapping between labels and ID
+            b_to_i_label (Sequence[int]): The mapping from each B- label id to the corresponding I- label id
+            padding (Union[bool, str]): Padding strategy
+            truncation (Union[bool, str]): Truncation strategy
+            is_split_into_words (bool): Is the input pre-split(tokenized)
+            max_length (Optional[int]): Max subword tokenized length for the model
+            label_all_tokens (bool): Whether to label sub words
+            token_ignore_label (str): The value of the token ignore label - we ignore these in the loss computation
+        """
+        self._tokenizer = tokenizer
+        self._token_column = token_column
+        self._label_column = label_column
+        self._label_to_id = label_to_id
+        self._b_to_i_label = b_to_i_label
+        # We can tell the tokenizer that we're dealing with ready-split tokens rather than full
+        # sentence strings by passing is_split_into_words=True.
+        # Set the following parameters using the kwargs
+        self._padding = padding
+        self._truncation = truncation
+        self._is_split_into_words = is_split_into_words
+        self._max_length = max_length
+        self._label_all_tokens = label_all_tokens
+        self._token_ignore_label = token_ignore_label
+        self._ignore_label = -100
+
+    def tokenize_and_align_labels(self, dataset: Dataset) -> Dataset:
+        """
+        This function is used to read the input dataset,
+        run the subword tokenization on the pre-split tokens and then,
+        as mentioned above, align the subtokens and labels and add the ignore
+        label. This will read the input - say [60, year, old, in, 2080]
+        and will return the subtokens - [60, year, old, in, 208, ##0]
+        some other information like token_type_ids etc
+        and the labels [0, 20, 20, 20, 3, -100] (0 corresponds to B-AGE, 20 corresponds to O
+        and 3 corresponds to B-DATE). This returned input serves as input for training the model
+        or for gathering predictions from a trained model.
+        Another important thing to note is that we have mentioned before that
+        we add chunks of tokens that appear before and after the current chunk for context. We would
+        also need to assign the label -100 (ignore_label) to these chunks, since we are using them
+        only to provide context. Basically if a token has the label NA, we don't use it for
+        training or evaluation. For example the input would be something
+        like tokens: [James, Doe, 60, year, old, in, 2080, BWH, tomorrow, only],
+        labels: [NA, NA, B-AGE, O, O, O, B-DATE, NA, NA, NA].
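+        (James, Doe and BWH, tomorrow, only are context tokens drawn from the neighbouring chunks).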
NA represents the tokens used for context
+        This function would return some tokenizer info (e.g attention mask etc), along with
+        the information that maps the tokens to the subtokens -
+        [James, Doe, 60, year, old, in, 208, ##0, BW, ##h, tomorrow, only]
+        and the labels - [-100, -100, 0, 20, 20, 20, 3, -100, -100, -100, -100, -100]
+        (if label_all_tokens was true, we would return [-100, -100, 0, 20, 20, 20, 3, 3, -100, -100, -100, -100]).
+        Args:
+            dataset (Dataset): The pre-split (tokenized) dataset that contains labels
+        Returns:
+            tokenized_inputs (Dataset): Subword tokenized and label aligned dataset
+        """
+        # Run the tokenizer - subword tokenization
+        tokenized_inputs = self._tokenizer(
+            dataset[self._token_column],
+            padding=self._padding,
+            truncation=self._truncation,
+            max_length=self._max_length,
+            is_split_into_words=self._is_split_into_words,
+        )
+        # Align the subwords and tokens
+        labels = [self.__get_labels(
+            labels,
+            tokenized_inputs.word_ids(batch_index=index)
+        ) for index, labels in enumerate(dataset[self._label_column])]
+        tokenized_inputs[self._label_column] = labels
+
+        return tokenized_inputs
+
+    def __get_labels(
+        self,
+        labels: Sequence[str],
+        word_ids: Sequence[int]
+    ) -> List[int]:
+        """
+        Go through the subword tokens - which are given as word_ids. 2 different tokens
+        2080 & John will have different word_ids, but the subword tokens 208 & ##0 will
+        have the same word_id, we use this to align and assign the labels accordingly.
+        If the subword tokens belong to [CLS], [SEP] append the ignore label (-100) to the
+        list of labels. If the subword token (##0) belongs to the token 2080
+        then the labels would be [3, -100] if label_all_tokens is false. Also if the token
+        is used only for context (with label NA) it would get the value -100 for its label
+        Args:
+            labels (Sequence[str]): The list of labels for the input (example)
+            word_ids (Sequence[int]): The word_ids after subword tokenization of the input
+        Returns:
+            label_ids (List[int]): The list of label ids for the input with the ignore label (-100) added
+                                   as required.
+        """
+        label_ids = list()
+        previous_word_idx = None
+        for word_idx in word_ids:
+            # Special tokens have a word id that is None. We set the label to -100 so they are automatically
+            # ignored in the loss function.
+            if word_idx is None:
+                label_ids.append(self._ignore_label)
+            # We set the label for the first token of each word.
+            elif word_idx != previous_word_idx:
+                if labels[word_idx] == self._token_ignore_label:
+                    label_ids.append(self._ignore_label)
+                else:
+                    label_ids.append(self._label_to_id[labels[word_idx]])
+            # For the other tokens in a word, we set the label to either the current label or -100, depending on
+            # the label_all_tokens flag.
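+            # e.g. for '##0' of 2080 (B-DATE, id 3): True appends the id of I-DATE,
+            # False appends -100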
+            else:
+                if labels[word_idx] == self._token_ignore_label:
+                    label_ids.append(self._ignore_label)
+                else:
+                    label_ids.append(
+                        self._b_to_i_label[self._label_to_id[labels[word_idx]]]
+                        if self._label_all_tokens else self._ignore_label
+                    )
+            previous_word_idx = word_idx
+        return label_ids
diff --git a/sequence_tagging/dataset_builder/label_mapper.py b/sequence_tagging/dataset_builder/label_mapper.py
new file mode 100644
index 0000000000000000000000000000000000000000..6b1417f6980338af1c75b50cd726402b804ead23
--- /dev/null
+++ b/sequence_tagging/dataset_builder/label_mapper.py
@@ -0,0 +1,87 @@
+from typing import List, Sequence, Mapping, Optional, NoReturn, Dict, Union
+from .ner_labels import NERLabels
+
+
+class LabelMapper(object):
+    """
+    This class is used to map one set of NER labels to another set of NER labels
+    For example we might want to map all NER labels to Binary HIPAA labels.
+    E.g:
+    We change the token labels - [B-AGE, O, O, U-LOC, B-DATE, L-DATE, O, B-STAFF, I-STAFF, L-STAFF] to
+    [B-HIPAA, O, O, U-HIPAA, B-HIPAA, L-HIPAA, O, O, O, O] or if we wanted binary I2B2 labels we map it to
+    [B-I2B2, O, O, U-I2B2, B-I2B2, L-I2B2, O, B-I2B2, I-I2B2, L-I2B2]
+    We do this mapping at the token and the span level. That is, if we have a span from, say, start=9
+    to end=15 labelled as LOC, we map this label to HIPAA or I2B2. This class maps an existing set of
+    labels to another set of labels
+    """
+
+    def __init__(
+        self,
+        notation: str,
+        ner_types: Sequence[str],
+        ner_types_maps: Sequence[str],
+        description: str
+    ) -> NoReturn:
+        """
+        Initialize the variables that will be used to map the NER labels and spans
+        The ner_map and spans_map should correspond to each other and contain the same NER types
+        Args:
+            notation (str): The notation that will be used for the NER labels (e.g BIO, BILOU)
+            ner_types (Sequence[str]): The original NER types
+            ner_types_maps (Sequence[str]): The NER type that each original type is mapped to
+            description (str): A description of the mapping
+        """
+        self._description = description
+        self._types = list(set(ner_types_maps))
+        self._types.sort()
+        self._spans_map = {ner_type: ner_type_map for ner_type, ner_type_map in zip(ner_types, ner_types_maps)}
+        ner_labels = NERLabels(notation=notation, ner_types=ner_types)
+        self._ner_map = dict()
+        for label in ner_labels.get_label_list():
+            if label == 'O' or self._spans_map[label[2:]] == 'O':
+                self._ner_map[label] = 'O'
+            else:
+                self._ner_map[label] = label[0:2] + self._spans_map[label[2:]]
+
+    def map_sequence(self, tag_sequence: Sequence[str]) -> List[str]:
+        """
+        Mapping a sequence of NER labels to another set of NER labels.
+        E.g: If we use a binary HIPAA mapping
+        This sequence [B-AGE, O, O, U-LOC, B-DATE, L-DATE, O, B-STAFF, I-STAFF, L-STAFF] will be mapped to
+        [B-HIPAA, O, O, U-HIPAA, B-HIPAA, L-HIPAA, O, O, O, O]
+        Args:
+            tag_sequence (Sequence[str]): A sequence of NER labels
+        Returns:
+            (List[str]): A mapped sequence of NER labels
+        """
+        return [self._ner_map[tag] for tag in tag_sequence]
+
+    def map_spans(self, spans: Sequence[Mapping[str, Union[str, int]]]) -> Sequence[Dict[str, Union[str, int]]]:
+        """
+        Mapping a sequence of NER spans to another set of NER spans.
+ E.g: If we use a binary HIPAA mapping + The spans: [{start:0, end:5, label: DATE}, {start:17, end:25, label: STAFF}, {start:43, end:54, label: PATIENT}] + will be mapped to: [{start:0, end:5, label: HIPAA}, {start:17, end:25, label: O}, {start:43, end:54, label: HIPAA}] + Return the original list of spans if no mapping is used (i.e the maps are == None) + Args: + spans (Sequence[Mapping[str, Union[str, int]]]): A sequence of NER spans + Returns: + (Sequence[Mapping[str, Union[str, int]]]): A mapped sequence of NER spans + """ + return [{'start': span['start'], 'end': span['end'], 'label': self._spans_map[span['label']]} \ + for span in spans] + + def get_ner_description(self) -> str: + """ + Get the description of the ner labels and span maps used + Returns: + (str): A description of the label/span maps used + """ + return self._description + + def get_ner_types(self) -> List[str]: + """ + Get the PHI types back from the list of NER labels + [B-AGE, I-AGE, B-DATE, I-DATE ..] ---> [AGE, DATE, ...] + Returns: + ner_types (List[str]): The list of unique NER types + """ + return self._types diff --git a/sequence_tagging/dataset_builder/ner_dataset.py b/sequence_tagging/dataset_builder/ner_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..deeb083c9294060a9ac60529cbc64eb50fa3b839 --- /dev/null +++ b/sequence_tagging/dataset_builder/ner_dataset.py @@ -0,0 +1,102 @@ +from typing import Sequence, Optional, NoReturn + +from datasets import load_dataset, Dataset + + +class NERDataset(object): + """ + This class is a wrapper around the huggingface datasets library + It maintains the train, validation and test datasets based on the + train, validation and test files passed by loading the dataset object + from the file and provides a get function to access each of the datasets. + """ + + def __init__( + self, + train_file: Optional[Sequence[str]] = None, + validation_file: Optional[Sequence[str]] = None, + test_file: Optional[Sequence[str]] = None, + extension: str = 'json', + shuffle: bool = True, + seed: int = 41 + ) -> NoReturn: + """ + Load the train, validation and test datasets from the files passed. Read the files and convert + it into a huggingface dataset. + Args: + train_file (Optional[Sequence[str]]): The list of files that contain train data + validation_file (Optional[Sequence[str]]): The list of files that contain validation data + test_file (Optional[Sequence[str]]): The list of files that contain test data + shuffle (bool): Whether to shuffle the dataset + seed (int): Shuffle seed + + """ + self._datasets = NERDataset.__prepare_data( + train_file, + validation_file, + test_file, + extension, + shuffle, + seed + ) + + @staticmethod + def __prepare_data( + train_file: Optional[Sequence[str]], + validation_file: Optional[Sequence[str]], + test_file: Optional[Sequence[str]], + extension: str, + shuffle: bool, + seed: int + ) -> Dataset: + """ + Get the train, validation and test datasets from the files passed. Read the files and convert + it into a huggingface dataset. + Args: + train_file (Optional[Sequence[str]]): The list of files that contain train data + validation_file (Optional[Sequence[str]]): The list of files that contain validation data + test_file (Optional[Sequence[str]]): The list of files that contain test data + shuffle (bool): Whether to shuffle the dataset + seed (int): Shuffle seed + Returns: + (Dataset): The huggingface dataset with train, validation, test splits (if included) + """ + # Read the datasets (train, validation, test etc). 
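+        # data_files maps each split to its file(s), e.g. {'train': [...], 'validation': [...]}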
+ data_files = {} + if train_file is not None: + data_files['train'] = train_file + if validation_file is not None: + data_files['validation'] = validation_file + if test_file is not None: + data_files['test'] = test_file + # Shuffle the dataset + if shuffle: + datasets = load_dataset(extension, data_files=data_files).shuffle(seed=seed) + else: + # Don't shuffle the dataset + datasets = load_dataset(extension, data_files=data_files) + return datasets + + def get_train_dataset(self) -> Dataset: + """ + Return the train dataset + Returns: + (Dataset): The huggingface dataset - train split + """ + return self._datasets['train'] + + def get_validation_dataset(self) -> Dataset: + """ + Return the validation dataset + Returns: + (Dataset): The huggingface dataset - validation split + """ + return self._datasets['validation'] + + def get_test_dataset(self) -> Dataset: + """ + Return the test dataset + Returns: + (Dataset): The huggingface dataset - test split + """ + return self._datasets['test'] diff --git a/sequence_tagging/dataset_builder/ner_labels.py b/sequence_tagging/dataset_builder/ner_labels.py new file mode 100644 index 0000000000000000000000000000000000000000..55db1d766009d88495f69921707cbcbbc8b67582 --- /dev/null +++ b/sequence_tagging/dataset_builder/ner_labels.py @@ -0,0 +1,67 @@ +from typing import Sequence, List, NoReturn, Dict + + +class NERLabels(object): + """ + Prepare the labels that will be used by the model. Parse the NER types + and prepare the NER labels. For example - NER Types: [AGE, DATE], + it will create a list like this (for BIO notation) [B-AGE, I-AGE, B-DATE, I-DATE, O] + These are the labels that will be assigned to the tokens based on the PHI type. + Say we had the following NER types: NAME, AGE, HOSP + The NER labels in the BIO notation would be B-AGE, B-HOSP, B-NAME, I-AGE, I-HOSP, I-NAME, O + This script creates a list of the NER labels ([B-AGE, B-HOSP, B-NAME, I-AGE, I-HOSP, I-NAME, O]) + based on the NER types (NAME, AGE, HOSP) that have been defined. Labels have been sorted. + The script also returns the number of labels, the label_to_id mapping, the id_to_label mapping + Label_id_mapping: {B-AGE:0, B-HOSP:1, B-NAME:2, I-AGE:3, I-HOSP:4, I-NAME:5, O:6} + This information will be used during training, evaluation and prediction. + """ + + def __init__(self, notation: str, ner_types: Sequence[str]) -> NoReturn: + """ + Initialize the notation that we are using for the NER task + Args: + notation (str): The notation that will be used for the NER labels + ner_types (Sequence[str]): The list of NER categories + """ + self._notation = notation + self._ner_types = ner_types + + def get_label_list(self) -> List[str]: + """ + Given the NER types return the NER labels. 
+ NER Types: [AGE, DATE] -> return a list like this (for BIO notation) [B-AGE, I-AGE, B-DATE, I-DATE, O] + Returns: + ner_labels (List[str]): The list of NER labels based on the NER notation (e.g BIO) + """ + # Add the 'O' (Outside - Non-phi) label to the list + if 'O' not in self._ner_types: + ner_labels = ['O'] + else: + ner_labels = list() + # Go through each label and prefix it based on the notation (e.g - B, I etc) + for ner_type in self._ner_types: + for ner_tag in list(self._notation): + if ner_tag != 'O': + ner_labels.append(ner_tag + '-' + ner_type) + ner_labels.sort() + return ner_labels + + def get_label_to_id(self) -> Dict[str, int]: + """ + Return a label to id mapping + Returns: + label_to_id (Dict[str, int]): label to id mapping + """ + labels = self.get_label_list() + label_to_id = {label: index_id for index_id, label in enumerate(labels)} + return label_to_id + + def get_id_to_label(self) -> Dict[int, str]: + """ + Return a id to label mapping + Returns: + id_to_label (Dict[int, str]): id to label mapping + """ + labels = self.get_label_list() + id_to_label = {index_id: label for index_id, label in enumerate(labels)} + return id_to_label diff --git a/sequence_tagging/evaluation/__init__.py b/sequence_tagging/evaluation/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6098744e6e82755996ad85c6c39dfd103b6b5a3f --- /dev/null +++ b/sequence_tagging/evaluation/__init__.py @@ -0,0 +1 @@ +from .metrics_compute import MetricsCompute \ No newline at end of file diff --git a/sequence_tagging/evaluation/metrics_compute.py b/sequence_tagging/evaluation/metrics_compute.py new file mode 100644 index 0000000000000000000000000000000000000000..148e8b42b85f700dc6c1663afd6d9b919625e72f --- /dev/null +++ b/sequence_tagging/evaluation/metrics_compute.py @@ -0,0 +1,202 @@ +from typing import Sequence, Tuple, Dict, NoReturn, Mapping, Union, Type + +from seqeval.scheme import IOB1, IOB2, IOBES, BILOU + + +class MetricsCompute(object): + """ + This is the evaluation script which is passed to the huggingface + trainer - specifically the compute_metrics function. The trainer uses + this function to run the evaluation on the validation dataset and log/save + the metrics. This script is used to evaluate the token and span level metrics + on the validation dataset by the huggingface trainer. The evaluation is also run + on the NER labels and spans produced by the different label mapper + objects. For example we might run the evaluation on the original list of NER labels/spans + and we also run the evaluation on binary HIPAA labels/spans. 
This is done by mapping the
+    NER labels & spans using the list of label_mapper objects present in label_mapper_list
+    The same evaluation script and metrics are first run on the original ner types/labels/spans
+    e.g:
+    [AGE, STAFF, DATE], [B-AGE, O, O, U-LOC, B-DATE, L-DATE, O, B-STAFF, I-STAFF, L-STAFF],
+    [{start:0, end:5, label: AGE}, {start:17, end:25, label: LOC}, {start:43, end:54, label: DATE},
+    {start:77, end:84, label: STAFF}]
+    and we also run on some mapped version of these ner types/labels/spans shown below
+    [HIPAA], [B-HIPAA, O, O, U-HIPAA, B-HIPAA, L-HIPAA, O, O, O, O], [{start:0, end:5, label: HIPAA},
+    {start:17, end:25, label: HIPAA}, {start:43, end:54, label: HIPAA}, {start:77, end:84, label: O}]
+    The context and subword tokens are excluded from the evaluation process
+    The results are returned - which are saved and logged
+    """
+
+    def __init__(
+        self,
+        metric,
+        note_tokens: Sequence[Sequence[Mapping[str, Union[str, int]]]],
+        note_spans: Sequence[Sequence[Mapping[str, Union[str, int]]]],
+        label_mapper_list: Sequence,
+        post_processor,
+        note_level_aggregator,
+        notation: str,
+        mode: str,
+        confusion_matrix: bool = False,
+        format_results: bool = True
+    ) -> NoReturn:
+        """
+        Initialize the variables used to perform evaluation: the evaluation (metric) object,
+        how the model predictions are decoded (e.g argmax, crf), the notation, the evaluation
+        mode and the label maps. The post processor object also handles excluding context and
+        subword tokens from the evaluation process. The note_tokens are used in the span level
+        evaluation process to check the character position of each token - and check if they
+        match with the character position of the spans. The note_spans are also used in the
+        span level evaluation process, they contain the position and labels of the spans.
+        Args:
+            metric (): The huggingface metric object, which contains the span and token level evaluation code
+            note_tokens (Sequence[Sequence[Mapping[str, Union[str, int]]]]): The list of tokens in the entire dataset
+            note_spans (Sequence[Sequence[Mapping[str, Union[str, int]]]]): The list of note spans in the entire dataset
+            post_processor (): Post processing the predictions (logits) - argmax, or crf etc.
The prediction logits are + passed to this object, which is then processed using the argmax of the logits or a + crf function to return the sequence of NER labels + note_level_aggregator (): Aggregate sentence level predictions back to note level for evaluation + using this object + label_mapper_list (Sequence): The list of label mapper object that are used to map ner spans and + labels for evaluation + notation (str): The NER notation + mode (str): The span level eval mode - strict or default + format_results (bool): Format the results - return either a single dict (true) or a dict of dicts (false) + """ + self._metric = metric + self._note_tokens = note_tokens + self._note_spans = note_spans + self._label_mapper_list = label_mapper_list + self._note_level_aggregator = note_level_aggregator + self._notation = notation + self._scheme = MetricsCompute.get_scheme(self._notation) + self._mode = mode + self._post_processor = post_processor + self._confusion_matrix = confusion_matrix + self._format_results = format_results + + @staticmethod + def get_scheme(notation: str) -> Union[Type[IOB2], Type[IOBES], Type[BILOU], Type[IOB1]]: + """ + Get the seqeval scheme based on the notation + Args: + notation (str): The NER notation + Returns: + (Union[IOB2, IOBES, BILOU, IOB1]): The seqeval scheme + """ + if notation == 'BIO': + return IOB2 + elif notation == 'BIOES': + return IOBES + elif notation == 'BILOU': + return BILOU + elif notation == 'IO': + return IOB1 + else: + raise ValueError('Invalid Notation') + + def run_metrics( + self, + note_labels: Sequence[Sequence[str]], + note_predictions: Sequence[Sequence[str]] + ) -> Union[Dict[str, Union[int, float]], Dict[str, Dict[str, Union[int, float]]]]: + """ + Run the evaluation metrics and return the span and token level results. + The metrics are run for each mapping of ner labels - based on the object in the + label_mapper_list. The evaluation is also run on the NER labels and spans produced by the different + label mapper objects. For example we might run the evaluation on the original list of NER labels/spans + and we also run the evaluation on binary HIPAA labels/spans. 
This is done by mapping the + NER labels & spans using the list of label_mapper object present in label_mapper_list + The same evaluation script and metrics are first run on the original ner types/labels/spans + e.g: + [AGE, STAFF, DATE], [B-AGE, O, O, U-LOC, B-DATE, L-DATE, O, B-STAFF, I-STAFF, L-STAFF], + [{start:0, end:5, label: AGE}, {start:17, end:25, label: LOC}, {start:43, end:54, label: DATE}, + {start:77, end:84, label: STAFF}] + and we also run on some mapped version of these ner types/labels/spans shown below + [HIPAA], [B-HIPAA, O, O, U-HIPAA, B-HIPAA, I-HIPAA, O, O, O, O], [{start:0, end:5, label: HIPAA}, + {start:17, end:25, label: HIPAA}, {start:43, end:54, label: HIPAA}, {start:77, end:84, label: O}] + Args: + note_labels (Sequence[Sequence[str]]): The list of NER labels for each note + note_predictions (Sequence[Sequence[str]]): The list of NER predictions for each notes + Returns: + final_results (Union[Dict[str, Union[int, float]], Dict[str, Dict[str, Union[int, float]]]]): Span and token + level + metric results + """ + final_results = {} + # Go through the list of different mapping (e.g HIPAA/I2B2) + for label_mapper in self._label_mapper_list: + # Get the NER information + ner_types = label_mapper.get_ner_types() + ner_description = label_mapper.get_ner_description() + # Map the NER labels and spans + predictions = [label_mapper.map_sequence(prediction) for prediction in note_predictions] + labels = [label_mapper.map_sequence(label) for label in note_labels] + spans = [label_mapper.map_spans(span) for span in self._note_spans] + # Run the span level and token level evaluation metrics + results = self._metric.compute( + predictions=predictions, + references=labels, + note_tokens=self._note_tokens, + note_spans=spans, + ner_types=ner_types, + ner_description=ner_description, + notation=self._notation, + scheme=self._scheme, + mode=self._mode, + confusion_matrix=self._confusion_matrix + ) + # Return the results as a single mapping or a nested mapping + if not self._format_results: + for key, value in results.items(): + final_results[key] = value + else: + for key, value in results.items(): + if isinstance(value, dict): + for n, v in value.items(): + final_results[f"{key}_{n}"] = v + else: + final_results[key] = value + # Return the results + return final_results + + def compute_metrics( + self, + p: Tuple[Sequence[Sequence[str]], Sequence[Sequence[str]]] + ) -> Union[Dict[str, Union[int, float]], Dict[str, Dict[str, Union[int, float]]]]: + """ + This script is used to compute the token and span level metrics when + the predictions and labels are passed. The first step is to convert the + model logits into the sequence of NER predictions using the post_processor + object (argmax, crf etc) and also exclude any context and subword tokens from the + evaluation process. Once we have the NER labels and predictions we run + the span and token level evaluation. + The evaluation is also run on the NER labels and spans produced by the different label mapper + objects. For example we might run the evaluation on the original list of NER labels/spans + and we also run the evaluation on binary HIPAA labels/spans. 
This is done by mapping the
+ NER labels & spans using the list of label_mapper objects present in label_mapper_list.
+ The same evaluation script and metrics are first run on the original NER types/labels/spans,
+ e.g:
+ [AGE, STAFF, DATE], [B-AGE, O, O, U-LOC, B-DATE, L-DATE, O, B-STAFF, I-STAFF, L-STAFF],
+ [{start:0, end:5, label: AGE}, {start:17, end:25, label: LOC}, {start:43, end:54, label: DATE},
+ {start:77, end:84, label: STAFF}]
+ and then on a mapped version of these NER types/labels/spans, shown below
+ [HIPAA], [B-HIPAA, O, O, U-HIPAA, B-HIPAA, I-HIPAA, O, O, O, O], [{start:0, end:5, label: HIPAA},
+ {start:17, end:25, label: HIPAA}, {start:43, end:54, label: HIPAA}, {start:77, end:84, label: O}]
+ Args:
+ p (Tuple[Sequence[Sequence[str]], Sequence[Sequence[str]]]): Tuple of model logits and labels
+ Returns:
+ final_results (Union[Dict[str, Union[int, float]], Dict[str, Dict[str, Union[int, float]]]]): Span and
+ token level metric results
+ """
+ predictions, labels = p
+ # Convert the logits (scores) to predictions
+ true_predictions, true_labels = self._post_processor.decode(predictions, labels)
+ # Aggregate sentence level labels and predictions to note level for evaluation
+ note_predictions = self._note_level_aggregator.get_aggregate_sequences(true_predictions)
+ note_labels = self._note_level_aggregator.get_aggregate_sequences(true_labels)
+ # Return results
+ return self.run_metrics(note_labels, note_predictions) diff --git a/sequence_tagging/evaluation/note_evaluation/__init__.py b/sequence_tagging/evaluation/note_evaluation/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/sequence_tagging/evaluation/note_evaluation/note_evaluation.py b/sequence_tagging/evaluation/note_evaluation/note_evaluation.py new file mode 100644 index 0000000000000000000000000000000000000000..489ff028cba134f856880e0cc60c6304a079c5d0 --- /dev/null +++ b/sequence_tagging/evaluation/note_evaluation/note_evaluation.py @@ -0,0 +1,223 @@ +""" modified seqeval metric. """
+from typing import Sequence, List, Optional, Type, Union, Mapping, Dict
+
+# This script uses the two other scripts note_sequence_evaluation.py
+# and note_token_evaluation.py to gather the span level and token
+# level metrics during the evaluation phase in the huggingface
+# training process.
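# --- Editor's sketch (illustrative, not part of the patch). compute_metrics
# has the (predictions, labels) shape the huggingface Trainer expects, so it
# can be wired in directly. The MetricsCompute keyword names are inferred from
# the assignments in __init__ above; metric, note_tokens, note_spans, the
# mappers, aggregator, post_processor, model and training_args are all assumed
# to be built elsewhere:
from transformers import Trainer

metrics_compute = MetricsCompute(
    metric=metric,
    note_tokens=note_tokens,
    note_spans=note_spans,
    label_mapper_list=[identity_mapper, hipaa_mapper],
    note_level_aggregator=note_level_aggregator,
    notation='BILOU',
    mode='strict',
    post_processor=post_processor,
    confusion_matrix=False,
    format_results=True,
)
trainer = Trainer(model=model, args=training_args,
                  compute_metrics=metrics_compute.compute_metrics)
# --- end sketch ---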
More information on how this script works +# can be found in - https://github.com/huggingface/datasets/tree/master/metrics/seqeval +# The code is borrowed from there and minor changes are made - to include token +# level metrics and evaluating spans at the character level as opposed to the +# token level +import datasets + +from .note_sequence_evaluation import NoteSequenceEvaluation +from .note_token_evaluation import NoteTokenEvaluation +from .violations import Violations + +_CITATION = """\ +@inproceedings{ramshaw-marcus-1995-text, + title = "Text Chunking using Transformation-Based Learning", + author = "Ramshaw, Lance and + Marcus, Mitch", + booktitle = "Third Workshop on Very Large Corpora", + year = "1995", + url = "https://www.aclweb.org/anthology/W95-0107", +} +@misc{seqeval, + title={{seqeval}: A Python framework for sequence labeling evaluation}, + url={https://github.com/chakki-works/seqeval}, + note={Software available from https://github.com/chakki-works/seqeval}, + author={Hiroki Nakayama}, + year={2018}, +} +""" + +_DESCRIPTION = """seqeval is a Python framework for sequence labeling evaluation. seqeval can evaluate the +performance of chunking tasks such as named-entity recognition, part-of-speech tagging, semantic role labeling and so +on. This is well-tested by using the Perl script conlleval, which can be used for measuring the performance of a +system that has processed the CoNLL-2000 shared task data. seqeval supports following formats: IOB1 IOB2 IOE1 IOE2 +IOBES See the [README.md] file at https://github.com/chakki-works/seqeval for more information. """ + +_KWARGS_DESCRIPTION = """ +Produces labelling scores along with its sufficient statistics +from a source against one or more references. +Args: + predictions: List of List of predicted labels (Estimated targets as returned by a tagger) + references: List of List of reference labels (Ground truth (correct) target values) + suffix: True if the IOB prefix is after type, False otherwise. default: False +Returns: + 'scores': dict. 
Summary of the scores for overall and per type
+ Overall:
+ 'accuracy': accuracy,
+ 'precision': precision,
+ 'recall': recall,
+ 'f1': F1 score, also known as balanced F-score or F-measure,
+ Per type:
+ 'precision': precision,
+ 'recall': recall,
+ 'f1': F1 score, also known as balanced F-score or F-measure
+Examples:
+ >>> predictions = [['O', 'O', 'B-MISC', 'I-MISC', 'I-MISC', 'I-MISC', 'O'], ['B-PER', 'I-PER', 'O']]
+ >>> references = [['O', 'O', 'O', 'B-MISC', 'I-MISC', 'I-MISC', 'O'], ['B-PER', 'I-PER', 'O']]
+ >>> seqeval = datasets.load_metric("seqeval")
+ >>> results = seqeval.compute(predictions=predictions, references=references)
+ >>> print(list(results.keys()))
+ ['MISC', 'PER', 'overall_precision', 'overall_recall', 'overall_f1', 'overall_accuracy']
+ >>> print(results["overall_f1"])
+ 0.5
+ >>> print(results["PER"]["f1"])
+ 1.0
+"""
+
+
+@datasets.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
+class NoteEvaluation(datasets.Metric):
+
+ def _info(self):
+ return datasets.MetricInfo(
+ description=_DESCRIPTION,
+ citation=_CITATION,
+ features=datasets.Features(
+ {
+ "references": datasets.Sequence(datasets.Value("string", id="label"), id="sequence"),
+ "predictions": datasets.Sequence(datasets.Value("string", id="label"), id="sequence"),
+ }
+ ),
+ inputs_description=_KWARGS_DESCRIPTION
+ )
+
+ def _compute(
+ self,
+ references: Sequence[Sequence[str]],
+ predictions: Sequence[Sequence[str]],
+ note_tokens: Sequence[Sequence[Mapping[str, Union[str, int]]]],
+ note_spans: Sequence[Sequence[Mapping[str, Union[str, int]]]],
+ ner_types: Sequence[str],
+ ner_description: str,
+ notation: str,
+ scheme: str,
+ mode: str,
+ confusion_matrix: bool = False,
+ suffix: bool = False,
+ sample_weight: Optional[List[int]] = None,
+ zero_division: Union[str, int] = "warn",
+ ) -> Dict[str, Dict[str, Union[int, float]]]:
+ """
+ Use the NoteSequenceEvaluation and NoteTokenEvaluation classes to extract the
+ token and span level precision, recall and f1 scores. Also return the micro averaged
+ precision, recall and f1 scores
+ Args:
+ references (Sequence[Sequence[str]]): The list of annotated labels in the evaluation dataset
+ predictions (Sequence[Sequence[str]]): The list of predictions in the evaluation dataset
+ note_tokens (Sequence[Sequence[Mapping[str, Union[str, int]]]]): The list of tokens for the notes
+ in the evaluation dataset
+ note_spans (Sequence[Sequence[Mapping[str, Union[str, int]]]]): The list of annotated spans for the notes
+ in the evaluation dataset
+ ner_types (Sequence[str]): The list of NER types e.g AGE, DATE etc
+ ner_description (str): A prefix added to the evaluation result keys
+ notation (str): The NER notation
+ scheme (Type[Token]): The NER labelling scheme
+ mode (str): Whether to use default or strict evaluation
+ confusion_matrix (bool): Whether to also return a token level confusion matrix
+ suffix (bool): Whether the B, I etc is in the prefix or the suffix
+ sample_weight : array-like of shape (n_samples,), default=None
+ Sample weights.
+ zero_division : "warn", 0 or 1, default="warn"
+ Sets the value to return when there is a zero division:
+ - recall: when there are no positive labels
+ - precision: when there are no positive predictions
+ - f-score: both
+ If set to "warn", this acts as 0, but warnings are also raised.
+ Returns:
+ (Dict[str, Dict[str, Union[int, float]]]): The token and span level metric scores
+ """
+ # Span level metric scores
+ report = NoteSequenceEvaluation.classification_report(
+ note_predictions=predictions,
+ note_tokens=note_tokens,
+ note_spans=note_spans,
+ ner_types=ner_types,
+ scheme=scheme,
+ mode=mode,
+ suffix=suffix,
+ output_dict=True,
+ sample_weight=sample_weight,
+ zero_division=zero_division,
+ )
+ # Token level metric scores
+ token_report = NoteTokenEvaluation.classification_report(
+ labels=references,
+ predictions=predictions,
+ ner_types=ner_types
+ )
+ num_violations = sum([Violations.get_violations(tag_sequence=prediction, notation=notation)
+ for prediction in predictions])
+ # Pop the macro and weighted averages - the macro averages are reported
+ # separately below and the weighted averages are discarded
+ macro_score = report.pop("macro avg")
+ report.pop("weighted avg")
+ macro_token_score = token_report.pop("macro avg")
+ token_report.pop("weighted avg")
+ overall_score = report.pop("micro avg")
+ token_overall_score = token_report.pop("micro avg")
+ # Extract span level scores for each NER type
+ scores = {
+ type_name: {
+ "precision": score["precision"],
+ "recall": score["recall"],
+ "f1": score["f1-score"],
+ "number": score["support"],
+ }
+ for type_name, score in report.items()
+ }
+ # Extract token level scores for each NER type
+ token_scores = {
+ type_name + '-TOKEN': {
+ "precision": score["precision"],
+ "recall": score["recall"],
+ "f1": score["f1-score"],
+ "number": score["support"],
+ }
+ for type_name, score in token_report.items()
+ }
+ # Extract the micro averaged span level score
+ overall = {'overall' + ner_description:
+ {"precision": overall_score["precision"],
+ "recall": overall_score["recall"],
+ "f1": overall_score["f1-score"],
+ }
+ }
+ # Extract the micro averaged token level score
+ token_overall = {'token-overall' + ner_description:
+ {"precision": token_overall_score["precision"],
+ "recall": token_overall_score["recall"],
+ "f1": token_overall_score["f1-score"],
+ }
+ }
+ # Extract the macro averaged span level score
+ macro_overall = {'macro-overall' + ner_description:
+ {"precision": macro_score["precision"],
+ "recall": macro_score["recall"],
+ "f1": macro_score["f1-score"],
+ }
+ }
+ # Extract the macro averaged token level score
+ macro_token_overall = {'macro-token-overall' + ner_description:
+ {"precision": macro_token_score["precision"],
+ "recall": macro_token_score["recall"],
+ "f1": macro_token_score["f1-score"],
+ }
+ }
+ # Store the number of NER violations
+ violation_count = {'violations' + ner_description: {'count': num_violations}}
+ # Return the results
+ if confusion_matrix:
+ confusion_matrix = {'confusion' + ner_description:
+ {'matrix': NoteTokenEvaluation.get_confusion_matrix(
+ labels=references,
+ predictions=predictions,
+ ner_types=ner_types
+ )}}
+ return {**scores, **overall, **token_scores, **token_overall, **macro_overall, **macro_token_overall,
+ **violation_count, **confusion_matrix}
+ else:
+ return {**scores, **overall, **token_scores, **token_overall, **macro_overall, **macro_token_overall,
+ **violation_count} diff --git a/sequence_tagging/evaluation/note_evaluation/note_evaluation.py.lock b/sequence_tagging/evaluation/note_evaluation/note_evaluation.py.lock new file mode 100755 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/sequence_tagging/evaluation/note_evaluation/note_sequence_evaluation.py b/sequence_tagging/evaluation/note_evaluation/note_sequence_evaluation.py new file mode 100644 index
0000000000000000000000000000000000000000..1d4ef6c2c54e6567677cb0963537151c9f939768 --- /dev/null +++ b/sequence_tagging/evaluation/note_evaluation/note_sequence_evaluation.py @@ -0,0 +1,516 @@ +# Script to evaluate at a span level
+# Sequence evaluation - the code is based on the seqeval library/package
+# While seqeval evaluates at the token position level, we evaluate at a
+# character position level. Since most of the code is the same, refer
+# to the library/github repo of the seqeval package for more details.
+import warnings
+from collections import defaultdict
+from typing import Sequence, List, Optional, Tuple, Type, Union, Mapping
+
+import numpy as np
+from seqeval.metrics.sequence_labeling import get_entities
+from seqeval.reporters import DictReporter, StringReporter
+from seqeval.scheme import Entities, Token
+from sklearn.exceptions import UndefinedMetricWarning
+
+PER_CLASS_SCORES = Tuple[List[float], List[float], List[float], List[int]]
+AVERAGE_SCORES = Tuple[float, float, float, int]
+SCORES = Union[PER_CLASS_SCORES, AVERAGE_SCORES]
+
+
+class NoteSequenceEvaluation(object):
+ """
+ There already exists a package (seqeval) that can do the sequence evaluation.
+ The reason we have this class is that seqeval looks at the task from a
+ token level classification perspective. It evaluates whether the spans formed by
+ the token classification (predictions) match or do not match the spans formed by the labels.
+ But in medical notes, there are many cases where the tokens and labels don't align.
+ E.g. inboston is one token, but the LOC span is in[boston] - it does not cover the
+ entire token. Since seqeval evaluates at a token level, it makes it hard to evaluate
+ or penalize models that don't handle such tokenization issues. This evaluation script is used
+ to evaluate the model at a character level. We essentially check whether the character positions
+ line up, as opposed to the token positions, in which case we can handle the evaluation of cases
+ like inboston. We borrow most of the code and intuition from seqeval and make changes
+ where necessary to suit our needs.
+ """
+
+ @staticmethod
+ def extract_predicted_spans_default(
+ tokens: Sequence[Mapping[str, Union[str, int]]],
+ predictions: Sequence[str],
+ suffix: str
+ ) -> defaultdict(set):
+ """
+ Use the seqeval get_entities method, which goes through the predictions and returns
+ where each span starts and ends. E.g. [O, O, B-AGE, I-AGE, O, O] will return a
+ span that starts at token 2 and ends at token 3 - with type AGE. We then extract the
+ position of the token in the note (character position) - so we return that
+ this span starts at 32 and ends at 37. The function then returns a dict
+ where the keys are the NER types and the values are the sets of different
+ positions at which these types occur within the note.
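# --- Editor's sketch (illustrative, not part of the patch). Why character
# positions matter, using the inboston example from the class docstring above.
# A tokenizer may emit one token 'inboston' at characters (8, 16) while the
# gold LOC span only covers 'boston' at (10, 16). Token level evaluation
# cannot express this partial overlap; character level evaluation simply
# compares the (start, end) pairs:
tokens = [{'text': 'visited', 'start': 0, 'end': 7},
          {'text': 'inboston', 'start': 8, 'end': 16},
          {'text': 'today', 'start': 17, 'end': 22}]
predictions = ['O', 'B-LOC', 'O']
gold_spans = [{'start': 10, 'end': 16, 'label': 'LOC'}]
# Predicted span (token-derived): (8, 16); gold span: (10, 16) -> no match,
# so a model that does not split 'inboston' is penalized, as intended.
# --- end sketch ---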
+ Args:
+ tokens (Sequence[Mapping[str, Union[str, int]]]): The list of tokens in the note
+ predictions (Sequence[str]): The list of predictions for the note
+ suffix (str): Whether the B, I etc is in the prefix or the suffix
+ Returns:
+ entities_pred (defaultdict(set)): Keys are the NER types and the value is a set that
+ contains the positions of these types
+ """
+ entities_pred = defaultdict(set)
+ for type_name, start, end in get_entities(predictions, suffix=suffix):
+ entities_pred[type_name].add((tokens[start]['start'], tokens[end]['end']))
+ return entities_pred
+
+ @staticmethod
+ def extract_predicted_spans_strict(
+ tokens: Sequence[Mapping[str, Union[str, int]]],
+ predictions: Sequence[str],
+ ner_types: Sequence[str],
+ scheme: Type[Token],
+ suffix: str
+ ) -> defaultdict(set):
+ """
+ Use the seqeval Entities class, which goes through the predictions and returns
+ where each span starts and ends. E.g. [O, O, B-AGE, I-AGE, O, O] will return a
+ span that starts at token 2 and ends at token 3 - with type AGE. We then extract the
+ position of the token in the note (character position) - so we return that
+ this span starts at 32 and ends at 37. The function then returns a dict
+ where the keys are the NER types and the values are the sets of different
+ positions at which these types occur within the note. The difference from the
+ extract_predicted_spans_default function is that this one is more strict - spans
+ need to start with the B tag, plus other constraints depending on the scheme
+ Args:
+ tokens (Sequence[Mapping[str, Union[str, int]]]): The list of tokens in the note
+ predictions (Sequence[str]): The list of predictions for the note
+ ner_types (Sequence[str]): The list of NER types e.g AGE, DATE etc
+ scheme (Type[Token]): The NER labelling scheme
+ suffix (str): Whether the B, I etc is in the prefix or the suffix
+ Returns:
+ entities_pred (defaultdict(set)): Keys are the NER types and the value is a set that
+ contains the positions of these types
+ """
+ entities_pred = defaultdict(set)
+ entities = Entities(sequences=[predictions], scheme=scheme, suffix=suffix)
+ for tag in ner_types:
+ for entity in entities.filter(tag):
+ entities_pred[entity.tag].add((tokens[entity.start]['start'], tokens[entity.end - 1]['end']))
+ return entities_pred
+
+ @staticmethod
+ def extract_true_spans(note_spans: Sequence[Mapping[str, Union[str, int]]]) -> defaultdict(set):
+ """
+ Go through the list of annotated spans and create a mapping like we do in the other
+ functions - the keys are the NER types and the values are the sets of different
+ positions (start, end) at which these NER types/spans occur within the note.
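# --- Editor's sketch (illustrative, not part of the patch). How the default
# and strict extraction modes above differ on the same predictions, assuming
# BIO/IOB2 tags:
from seqeval.metrics.sequence_labeling import get_entities
from seqeval.scheme import Entities, IOB2

preds = ['O', 'I-AGE', 'I-AGE', 'O']
# default mode is lenient - the I-AGE run still counts as an AGE entity
print(get_entities(preds))  # [('AGE', 1, 2)]
# strict mode drops the span because it never opened with B-AGE
entities = Entities(sequences=[preds], scheme=IOB2, suffix=False)
print([entity for sent in entities.entities for entity in sent])  # []
# --- end sketch ---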
+ Args:
+ note_spans (Sequence[Mapping[str, Union[str, int]]]): The list of spans in the note
+ Returns:
+ entities_true (defaultdict(set)): Keys are the NER types and the value is a set that
+ contains the positions of these types
+ """
+ entities_true = defaultdict(set)
+ for span in note_spans:
+ entities_true[span['label']].add((int(span['start']), int(span['end'])))
+ return entities_true
+
+ @staticmethod
+ def extract_tp_actual_correct(
+ note_predictions: Sequence[Sequence[str]],
+ note_tokens: Sequence[Sequence[Mapping[str, Union[str, int]]]],
+ note_spans: Sequence[Sequence[Mapping[str, Union[str, int]]]],
+ ner_types: Sequence[str],
+ scheme: Type[Token],
+ mode: str,
+ suffix: str
+ ) -> Tuple[List, List, List]:
+ """
+ Extract the number of gold spans per NER type, the number of predicted spans per
+ NER type and the number of spans where the gold standard and predicted spans match, for all
+ the notes in the evaluation dataset. This is mainly done by comparing the gold standard span
+ positions with the predicted span positions using the extract_predicted_spans_default,
+ extract_predicted_spans_strict and extract_true_spans functions.
+ The annotated spans argument is a list that contains a list of spans for each note. These spans contain
+ the span label and the position (start, end) of the span in the note (character positions).
+ We use this as our ground truth. The reason we do this is that for medical notes it's better
+ to have character level positions, because it makes it easier to evaluate notes with typos
+ and run-together tokens.
+ Note tokens is a list that in turn contains a list of the tokens present in each note. For each token
+ we have its start and end position (character positions) in the note. The note_spans and
+ note_tokens remain constant throughout the evaluation of the model predictions.
+ We use the note tokens to map the predictions of the model to character positions and then
+ compare them with the character positions of the annotated spans.
+ Args: + note_predictions (Sequence[Sequence[str]]): The list of predictions in the evaluation dataset + note_spans (Sequence[Sequence[Mapping[str, Union[str, int]]]]): The list of annotated spans for the notes + in the evaluation dataset + note_tokens (Sequence[Sequence[Mapping[str, Union[str, int]]]]): The list of tokens for the notes + in the evaluation dataset + ner_types (Sequence[str]): The list of NER types e.g AGE, DATE etc + scheme (Type[Token]): The NER labelling scheme + mode (str): Whether to use default or strict evaluation + suffix (str): Whether the B, I etc is in the prefix or the suffix + Returns: + pred_sum (np.array): The number of predicted spans + tp_sum (np.array): The number of predicted spans that match gold standard spans + true_sum (np.array): The number of gold standard spans + """ + # Initialize the arrays that will store the number of predicted spans per NER type + # the gold standard number of spans per NER type and the number of spans that match between + # the predicted and actual (true positives) + tp_sum = np.zeros(len(ner_types), dtype=np.int32) + pred_sum = np.zeros(len(ner_types), dtype=np.int32) + true_sum = np.zeros(len(ner_types), dtype=np.int32) + # Calculate the number of true positives, predicted and actual number of spans per NER type + # for each note and sum up the results + for spans, tokens, predictions in zip(note_spans, note_tokens, note_predictions): + # Get all the gold standard spans + entities_true = NoteSequenceEvaluation.extract_true_spans(note_spans=spans) + # Get all the predicted spans + if mode == 'default': + entities_pred = NoteSequenceEvaluation.extract_predicted_spans_default( + tokens=tokens, + predictions=predictions, + suffix=suffix + ) + elif mode == 'strict': + entities_pred = NoteSequenceEvaluation.extract_predicted_spans_strict( + tokens=tokens, + predictions=predictions, + ner_types=ner_types, + scheme=scheme, + suffix=suffix + ) + else: + raise ValueError('Invalid Mode') + # Calculate and store the number of the gold standard spans, predicted spans and true positives + # for each NER type + for ner_index, ner_type in enumerate(ner_types): + entities_true_type = entities_true.get(ner_type, set()) + entities_pred_type = entities_pred.get(ner_type, set()) + tp_sum[ner_index] += len(entities_true_type & entities_pred_type) + pred_sum[ner_index] += len(entities_pred_type) + true_sum[ner_index] += len(entities_true_type) + return pred_sum, tp_sum, true_sum + + @staticmethod + def precision_recall_fscore( + note_predictions: Sequence[Sequence[str]], + note_tokens: Sequence[Sequence[Mapping[str, Union[str, int]]]], + note_spans: Sequence[Sequence[Mapping[str, Union[str, int]]]], + ner_types: Sequence[str], + scheme: Type[Token], + mode: str, + *, + average: Optional[str] = None, + warn_for=('precision', 'recall', 'f-score'), + beta: float = 1.0, + sample_weight: Optional[List[int]] = None, + zero_division: str = 'warn', + suffix: bool = False + ) -> SCORES: + """ + Extract the precision, recall and F score based on the number of predicted spans per + NER type, the number of spans where the gold standard and predicted spans match for all + the notes in the evaluation dataset. 
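# --- Editor's sketch (illustrative, not part of the patch). How the three
# count arrays produced by extract_tp_actual_correct turn into scores. Toy
# numbers, with ner_types = ['AGE', 'DATE']:
import numpy as np

pred_sum = np.array([4, 6])  # predicted spans per type
tp_sum = np.array([3, 4])    # predicted spans that exactly match a gold span
true_sum = np.array([5, 5])  # gold spans per type
# per type: precision = tp / pred, recall = tp / true
#   AGE:  P = 3/4 = 0.75, R = 3/5 = 0.60;  DATE: P = 4/6 ~ 0.67, R = 4/5 = 0.80
# micro average pools the counts before dividing:
precision = tp_sum.sum() / pred_sum.sum()           # 7/10 = 0.70
recall = tp_sum.sum() / true_sum.sum()              # 7/10 = 0.70
f1 = 2 * precision * recall / (precision + recall)  # 0.70
# --- end sketch ---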
+ Return the precision, recall, f1 scores for each NER type and averaged scores (micro, macro etc) + Args: + note_predictions (Sequence[Sequence[str]]): The list of predictions in the evaluation dataset + note_spans (Sequence[Sequence[Mapping[str, Union[str, int]]]]): The list of annotated spans for the notes + in the evaluation dataset + note_tokens (Sequence[Sequence[Mapping[str, Union[str, int]]]]): The list of tokens for the notes + in the evaluation dataset + ner_types (Sequence[str]): The list of NER types e.g AGE, DATE etc + scheme (Type[Token]): The NER labelling scheme + mode (str): Whether to use default or strict evaluation + suffix (str): Whether the B, I etc is in the prefix or the suffix + sample_weight : array-like of shape (n_samples,), default=None + Sample weights. + zero_division : "warn", 0 or 1, default="warn" + Sets the value to return when there is a zero division: + - recall: when there are no positive labels + - precision: when there are no positive predictions + - f-score: both + If set to "warn", this acts as 0, but warnings are also raised. + Returns: + (SCORES): Precision, recall, f1 scores for each NER type - and averaged scores (micro, macro etc) + """ + if beta < 0: + raise ValueError('beta should be >=0 in the F-beta score') + + average_options = (None, 'micro', 'macro', 'weighted') + if average not in average_options: + raise ValueError('average has to be one of {}'.format(average_options)) + # Calculate and store the number of the gold standard spans, predicted spans and true positives + # for each NER type - this will be used to calculate the precision, recall and f1 scores + pred_sum, tp_sum, true_sum = NoteSequenceEvaluation.extract_tp_actual_correct( + note_predictions=note_predictions, + note_tokens=note_tokens, + note_spans=note_spans, + ner_types=ner_types, + scheme=scheme, + mode=mode, + suffix=suffix + ) + + if average == 'micro': + tp_sum = np.array([tp_sum.sum()]) + pred_sum = np.array([pred_sum.sum()]) + true_sum = np.array([true_sum.sum()]) + + # Finally, we have all our sufficient statistics. Divide! # + beta2 = beta ** 2 + + # Divide, and on zero-division, set scores and/or warn according to + # zero_division: + precision = NoteSequenceEvaluation._prf_divide( + numerator=tp_sum, + denominator=pred_sum, + metric='precision', + modifier='predicted', + average=average, + warn_for=warn_for, + zero_division=zero_division + ) + recall = NoteSequenceEvaluation._prf_divide( + numerator=tp_sum, + denominator=true_sum, + metric='recall', + modifier='true', + average=average, + warn_for=warn_for, + zero_division=zero_division + ) + + # warn for f-score only if zero_division is warn, it is in warn_for + # and BOTH precision and recall are ill-defined + if zero_division == 'warn' and ('f-score',) == warn_for: + if (pred_sum[true_sum == 0] == 0).any(): + NoteSequenceEvaluation._warn_prf( + average, 'true nor predicted', 'F-score is', len(true_sum) + ) + + # if tp == 0 F will be 1 only if all predictions are zero, all labels are + # zero, and zero_division=1. In all other case, 0 + if np.isposinf(beta): + f_score = recall + else: + denom = beta2 * precision + recall + + denom[denom == 0.] 
= 1 # avoid division by 0
+ f_score = (1 + beta2) * precision * recall / denom
+
+ # Average the results
+ if average == 'weighted':
+ weights = true_sum
+ if weights.sum() == 0:
+ zero_division_value = 0.0 if zero_division in ['warn', 0] else 1.0
+ # precision is zero_division if there are no positive predictions
+ # recall is zero_division if there are no positive labels
+ # fscore is zero_division if all labels AND predictions are
+ # negative
+ return (
+ zero_division_value if pred_sum.sum() == 0 else 0.0,
+ zero_division_value,
+ zero_division_value if pred_sum.sum() == 0 else 0.0,
+ sum(true_sum)
+ )
+
+ elif average == 'samples':
+ weights = sample_weight
+ else:
+ weights = None
+
+ if average is not None:
+ precision = np.average(precision, weights=weights)
+ recall = np.average(recall, weights=weights)
+ f_score = np.average(f_score, weights=weights)
+ true_sum = sum(true_sum)
+
+ return precision, recall, f_score, true_sum
+
+ @staticmethod
+ def classification_report(
+ note_predictions,
+ note_tokens,
+ note_spans,
+ ner_types: Sequence[str],
+ scheme: Type[Token],
+ mode: str,
+ *,
+ sample_weight: Optional[List[int]] = None,
+ digits: int = 2,
+ output_dict: bool = False,
+ zero_division: str = 'warn',
+ suffix: bool = False
+ ) -> Union[str, dict]:
+ """
+ Build a text report showing the main tagging metrics.
+ Args:
+ note_predictions (Sequence[Sequence[str]]): The list of predictions in the evaluation dataset
+ note_spans (Sequence[Sequence[Mapping[str, Union[str, int]]]]): The list of annotated spans for the notes
+ in the evaluation dataset
+ note_tokens (Sequence[Sequence[Mapping[str, Union[str, int]]]]): The list of tokens for the notes
+ in the evaluation dataset
+ ner_types (Sequence[str]): The list of NER types e.g AGE, DATE etc
+ scheme (Type[Token]): The NER labelling scheme
+ mode (str): Whether to use default or strict evaluation
+ suffix (str): Whether the B, I etc is in the prefix or the suffix
+ sample_weight : array-like of shape (n_samples,), default=None
+ Sample weights.
+ digits (int): Number of digits for formatting output floating point values.
+ output_dict (bool(default=False)): If True, return output as dict else str.
+ zero_division : "warn", 0 or 1, default="warn"
+ Sets the value to return when there is a zero division:
+ - recall: when there are no positive labels
+ - precision: when there are no positive predictions
+ - f-score: both
+ If set to "warn", this acts as 0, but warnings are also raised.
+ Returns:
+ report : string/dict. Summary of the precision, recall, F1 score for each class.
+ Examples:
+ >>> from seqeval.metrics.v1 import classification_report
+ >>> y_true = [['O', 'O', 'O', 'B-MISC', 'I-MISC', 'I-MISC', 'O'], ['B-PER', 'I-PER', 'O']]
+ >>> y_pred = [['O', 'O', 'B-MISC', 'I-MISC', 'I-MISC', 'I-MISC', 'O'], ['B-PER', 'I-PER', 'O']]
+ >>> print(classification_report(y_true, y_pred))
+ precision recall f1-score support
+
+ MISC 0.00 0.00 0.00 1
+ PER 1.00 1.00 1.00 1
+
+ micro avg 0.50 0.50 0.50 2
+ macro avg 0.50 0.50 0.50 2
+ weighted avg 0.50 0.50 0.50 2
+
+ """
+ NoteSequenceEvaluation.check_consistent_length(note_tokens, note_predictions)
+ if len(note_spans) != len(note_tokens):
+ raise ValueError('Number of spans and number of notes mismatch')
+
+ if output_dict:
+ reporter = DictReporter()
+ else:
+ name_width = max(map(len, ner_types))
+ avg_width = len('weighted avg')
+ width = max(name_width, avg_width, digits)
+ reporter = StringReporter(width=width, digits=digits)
+
+ # compute per-class scores.
+ p, r, f1, s = NoteSequenceEvaluation.precision_recall_fscore( + note_predictions=note_predictions, + note_tokens=note_tokens, + note_spans=note_spans, + ner_types=ner_types, + scheme=scheme, + mode=mode, + average=None, + sample_weight=sample_weight, + zero_division=zero_division, + suffix=suffix + ) + for row in zip(ner_types, p, r, f1, s): + reporter.write(*row) + reporter.write_blank() + + # compute average scores. + average_options = ('micro', 'macro', 'weighted') + for average in average_options: + avg_p, avg_r, avg_f1, support = NoteSequenceEvaluation.precision_recall_fscore( + note_predictions=note_predictions, + note_tokens=note_tokens, + note_spans=note_spans, + ner_types=ner_types, + scheme=scheme, + mode=mode, + average=average, + sample_weight=sample_weight, + zero_division=zero_division, + suffix=suffix) + reporter.write('{} avg'.format(average), avg_p, avg_r, avg_f1, support) + reporter.write_blank() + + return reporter.report() + + @staticmethod + def _prf_divide( + numerator, + denominator, + metric, + modifier, + average, + warn_for, + zero_division='warn' + ): + """ + Performs division and handles divide-by-zero. + On zero-division, sets the corresponding result elements equal to + 0 or 1 (according to ``zero_division``). Plus, if + ``zero_division != "warn"`` raises a warning. + The metric, modifier and average arguments are used only for determining + an appropriate warning. + """ + mask = denominator == 0.0 + denominator = denominator.copy() + denominator[mask] = 1 # avoid infs/nans + result = numerator / denominator + + if not np.any(mask): + return result + + # if ``zero_division=1``, set those with denominator == 0 equal to 1 + result[mask] = 0.0 if zero_division in ['warn', 0] else 1.0 + + # the user will be removing warnings if zero_division is set to something + # different than its default value. If we are computing only f-score + # the warning will be raised only if precision and recall are ill-defined + if zero_division != 'warn' or metric not in warn_for: + return result + + # build appropriate warning + # E.g. "Precision and F-score are ill-defined and being set to 0.0 in + # labels with no predicted samples. Use ``zero_division`` parameter to + # control this behavior." + + if metric in warn_for and 'f-score' in warn_for: + msg_start = '{0} and F-score are'.format(metric.title()) + elif metric in warn_for: + msg_start = '{0} is'.format(metric.title()) + elif 'f-score' in warn_for: + msg_start = 'F-score is' + else: + return result + + NoteSequenceEvaluation._warn_prf(average, modifier, msg_start, len(result)) + + return result + + @staticmethod + def _warn_prf(average, modifier, msg_start, result_size): + axis0, axis1 = 'sample', 'label' + if average == 'samples': + axis0, axis1 = axis1, axis0 + msg = ('{0} ill-defined and being set to 0.0 {{0}} ' + 'no {1} {2}s. Use `zero_division` parameter to control' + ' this behavior.'.format(msg_start, modifier, axis0)) + if result_size == 1: + msg = msg.format('due to') + else: + msg = msg.format('in {0}s with'.format(axis1)) + warnings.warn(msg, UndefinedMetricWarning, stacklevel=2) + + @staticmethod + def check_consistent_length( + note_tokens: Sequence[Sequence[Mapping[str, Union[str, int]]]], + note_predictions: Sequence[Sequence[str]] + ): + """ + Check that all arrays have consistent first and second dimensions. + Checks whether all objects in arrays have the same shape or length. + Args: + y_true : 2d array. + y_pred : 2d array. 
+ """ + len_tokens = list(map(len, note_tokens)) + len_predictions = list(map(len, note_predictions)) + + if len(note_tokens) != len(note_predictions) or len_tokens != len_predictions: + message = 'Found input variables with inconsistent numbers of samples:\n{}\n{}'.format(len_tokens, + len_predictions) + raise ValueError(message) diff --git a/sequence_tagging/evaluation/note_evaluation/note_token_evaluation.py b/sequence_tagging/evaluation/note_evaluation/note_token_evaluation.py new file mode 100644 index 0000000000000000000000000000000000000000..081a6b609170af589bbad6fd2a728678f8c183dd --- /dev/null +++ b/sequence_tagging/evaluation/note_evaluation/note_token_evaluation.py @@ -0,0 +1,134 @@ +from collections import Counter +from typing import Sequence, List, Tuple, Union, Type, Optional + +from seqeval.reporters import DictReporter +from sklearn.metrics import precision_score, recall_score, f1_score, confusion_matrix + + +class NoteTokenEvaluation(object): + """ + This class is used to evaluate token level precision, recall and F1 scores. + Script to evaluate at a token level. Calculate precision, recall, and f1 metrics + at the token level rather than the span level. + """ + + @staticmethod + def unpack_nested_list(nested_list: Sequence[Sequence[str]]) -> List[str]: + """ + Use this function to unpack a nested list and also for token level evaluation we dont + need to consider the B, I prefixes (depending on the NER notation, so remove that as well. + Args: + nested_list (Sequence[Sequence[str]]): A nested list of predictions/labels + Returns: + (List[str]): Unpacked nested list of predictions/labels + """ + return [inner if inner == 'O' else inner[2:] for nested in nested_list for inner in nested] + + @staticmethod + def get_counts(sequence: Sequence[str], ner_types: Sequence[str]) -> List[int]: + """ + Use this function to get the counts for each NER type + Args: + ner_list (Sequence[str]): A list of the NER labels/predicitons + Returns: + (List[int]): Position 0 contains the counts for the NER type that corresponds to position 0 + """ + counts = Counter() + counts.update(sequence) + return [counts[ner_type] for ner_type in ner_types] + + @staticmethod + def precision_recall_fscore( + labels: Sequence[str], + predictions: Sequence[str], + ner_types: Sequence[str], + average: Optional[str] = None + ) -> Tuple[Union[float, List[float]], Union[float, List[float]], Union[float, List[float]], Union[int, List[int]]]: + """ + Use this function to get the token level precision, recall and fscore. Internally we use the + sklearn precision_score, recall_score and f1 score functions. Also return the count of each + NER type. 
+ Args:
+ labels (Sequence[str]): A list of the gold standard NER labels
+ predictions (Sequence[str]): A list of the predicted NER labels
+ ner_types (Sequence[str]): The list of NER types e.g AGE, DATE etc
+ average (Optional[str]): None for per NER type scores, or pass an appropriate average value
+ Returns:
+ eval_precision (Union[float, List[float]]): precision score (averaged or per ner type)
+ eval_recall (Union[float, List[float]]): recall score (averaged or per ner type)
+ eval_f1 (Union[float, List[float]]): F1 score (averaged or per ner type)
+ counts (Union[int, List[int]]): Counts (total or per ner type)
+ """
+ eval_precision = precision_score(y_true=labels, y_pred=predictions, labels=ner_types, average=average)
+ eval_recall = recall_score(y_true=labels, y_pred=predictions, labels=ner_types, average=average)
+ eval_f1 = f1_score(y_true=labels, y_pred=predictions, labels=ner_types, average=average)
+ counts = NoteTokenEvaluation.get_counts(sequence=labels, ner_types=ner_types)
+ if average is None:
+ eval_precision = list(eval_precision)
+ eval_recall = list(eval_recall)
+ eval_f1 = list(eval_f1)
+ else:
+ counts = sum(counts)
+ return eval_precision, eval_recall, eval_f1, counts
+
+ @staticmethod
+ def get_confusion_matrix(labels: Sequence[str], predictions: Sequence[str], ner_types: Sequence[str]):
+ """
+ Use this function to get the token level confusion matrix. The nested labels and
+ predictions are flattened, stripped of their notation prefixes, and the O label is
+ included as its own class.
+ Args:
+ labels (Sequence[str]): A list of the gold standard NER labels
+ predictions (Sequence[str]): A list of the predicted NER labels
+ ner_types (Sequence[str]): The list of NER types e.g AGE, DATE etc
+ Returns:
+ Confusion matrix of token level counts (one row/column per NER type plus O)
+ """
+ labels = NoteTokenEvaluation.unpack_nested_list(labels)
+ predictions = NoteTokenEvaluation.unpack_nested_list(predictions)
+ return confusion_matrix(y_true=labels, y_pred=predictions, labels=ner_types + ['O', ])
+
+ @staticmethod
+ def classification_report(
+ labels: Sequence[Sequence[str]],
+ predictions: Sequence[Sequence[str]],
+ ner_types: Sequence[str]
+ ) -> Type[DictReporter]:
+ """
+ Use this function to get the token level precision, recall and fscore per NER type
+ and also the micro, macro and weighted averaged precision, recall and f scores.
+ Essentially we return a classification report which contains all this information
+ Args:
+ labels (Sequence[Sequence[str]]): A list of the gold standard NER labels for each note
+ predictions (Sequence[Sequence[str]]): A list of the predicted NER labels for each note
+ ner_types (Sequence[str]): The list of NER types e.g AGE, DATE etc
+ Returns:
+ (Type[DictReporter]): Classification report that contains the token level metric scores
+ """
+ # Unpack the nested lists (labels and predictions) before running the evaluation metrics
+ labels = NoteTokenEvaluation.unpack_nested_list(nested_list=labels)
+ predictions = NoteTokenEvaluation.unpack_nested_list(nested_list=predictions)
+ # Store the results in this object and return it
+ reporter = DictReporter()
+ # Calculate precision, recall and f1 for each NER type
+ eval_precision, eval_recall, eval_f1, counts = NoteTokenEvaluation.precision_recall_fscore(
+ labels=labels,
+ predictions=predictions,
+ ner_types=ner_types,
+ average=None
+ )
+ # Store the results
+ for row in zip(ner_types, eval_precision, eval_recall, eval_f1, counts):
+ reporter.write(*row)
+ reporter.write_blank()
+ # Calculate the overall precision, recall and f1 - based on the defined averages
+ average_options = ('micro', 'macro', 'weighted')
+ for average in average_options:
+ eval_precision, eval_recall, eval_f1, counts = NoteTokenEvaluation.precision_recall_fscore(
+ labels=labels,
+ predictions=predictions,
+ ner_types=ner_types,
+ average=average
+ )
+ # Store the results
+ reporter.write('{} avg'.format(average), eval_precision, eval_recall, eval_f1, counts)
+ reporter.write_blank()
+ # Return the token level results
+ return reporter.report() diff --git a/sequence_tagging/evaluation/note_evaluation/violations.py b/sequence_tagging/evaluation/note_evaluation/violations.py new file mode 100644 index 0000000000000000000000000000000000000000..d01aa1a656284f29e393153fae26f3f6b88e4256 --- /dev/null +++ b/sequence_tagging/evaluation/note_evaluation/violations.py @@ -0,0 +1,99 @@ +# Get the number of violations in the predicted output
+from typing import NoReturn, Sequence, Tuple
+
+
+class Violations(object):
+ """
+ This class is used to compute the violations in the predictions.
+ A violation is, for example, when `I-TYPE` follows `O`
+ or a tag of a different type.
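# --- Editor's sketch (illustrative, not part of the patch). Given the
# finished class below, a BIO sequence with two violations - I-AGE follows O
# (the span never opened with B-AGE), and I-DATE continues a span that was
# opened as AGE:
tags = ['O', 'I-AGE', 'B-AGE', 'I-DATE', 'O']
print(Violations.get_violations(tag_sequence=tags, notation='BIO'))  # 2
# --- end sketch ---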
+ """ + + @staticmethod + def get_prefixes(notation: str) -> Tuple[str, str, str, str]: + """ + Initialize variables that are used to check for violations based on the notation + Args: + notation (str): The NER labelling scheme + Returns: + prefix_single, prefix_begin, prefix_inside, prefix_end, prefix_outside (Tuple[str, str, str, str]): The prefixes in + the labels based + on the labelling + scheme + """ + # Define the variables that represent the tags based on the notation + if notation == 'BIO': + prefix_single = 'B' + prefix_begin = 'B' + prefix_inside = 'I' + prefix_end = 'I' + prefix_outside = 'O' + elif notation == 'BIOES': + prefix_single = 'S' + prefix_begin = 'B' + prefix_inside = 'I' + prefix_end = 'E' + prefix_outside = 'O' + elif notation == 'BILOU': + prefix_single = 'U' + prefix_begin = 'B' + prefix_inside = 'I' + prefix_end = 'L' + prefix_outside = 'O' + elif notation == 'IO': + prefix_single = 'I' + prefix_begin = 'I' + prefix_inside = 'I' + prefix_end = 'I' + prefix_outside = 'O' + else: + raise ValueError('Invalid Notation') + return prefix_single, prefix_begin, prefix_inside, prefix_end, prefix_outside + + @staticmethod + def get_violations(tag_sequence: Sequence[str], notation: str) -> int: + """ + Compute the violations in the predictions/labels + A violation is something like i.e., how many times `I-TYPE` follows `O` + or a tag of a different type. + Args: + tag_sequence (Sequence[str]): The predictions/labels (e.g O, B-DATE, I-AGE) + notation (str): The NER labelling scheme + Returns: + count (int): The number of violations + """ + prefix_single, prefix_begin, prefix_inside, prefix_end, prefix_outside = Violations.get_prefixes( + notation=notation) + count = 0 + start_tag = None + prev_tag_type = prefix_single + for tag in tag_sequence: + tag_split = tag.split('-') + # Check if the current tag is the beginning of a span or is a unit span (span of 1 token) + if tag_split[0] in [prefix_begin, prefix_single]: + # If the previous tag is not O, END (E,L) or UNIT (S, U) then it is a violation + # Since this span started and the previous span did not end + if prev_tag_type not in [prefix_outside, prefix_end, prefix_single]: + count += 1 + start_tag = tag_split[1] + prev_tag_type = tag_split[0] + # Check if the current tag is the inside/end of a span + # If it is preceeded by the O tag - then it is a violation - because this span + # does not have a begin tag (B) + elif tag_split[0] in [prefix_inside, prefix_end] and prev_tag_type == prefix_outside: + count += 1 + start_tag = tag_split[1] + prev_tag_type = tag_split[0] + # Check if the current tag is the inside/end of a span - if the type of the span + # is different then it is a violation. 
E.g. DATE followed by AGE while the DATE span has not ended
+ elif tag_split[0] in [prefix_inside, prefix_end] and prev_tag_type != prefix_outside:
+ if prev_tag_type not in [prefix_inside, prefix_begin]:
+ count += 1
+ elif tag_split[1] != start_tag:
+ count += 1
+ start_tag = tag_split[1]
+ prev_tag_type = tag_split[0]
+ else:
+ start_tag = None
+ prev_tag_type = prefix_outside
+ return count diff --git a/sequence_tagging/evaluation/results/__init__.py b/sequence_tagging/evaluation/results/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..fd34bfeab6673b5a4bd62944b849d3faa9e5e8ca --- /dev/null +++ b/sequence_tagging/evaluation/results/__init__.py @@ -0,0 +1,2 @@ +from .results_formatter import ResultsFormatter
+__all__ = ["ResultsFormatter"] \ No newline at end of file diff --git a/sequence_tagging/evaluation/results/results_formatter.py b/sequence_tagging/evaluation/results/results_formatter.py new file mode 100644 index 0000000000000000000000000000000000000000..d1b230da5ef7ba05da7bfe40a86864fe52bfcf5f --- /dev/null +++ b/sequence_tagging/evaluation/results/results_formatter.py @@ -0,0 +1,64 @@ +import re
+import numpy as np
+import pandas as pd
+import matplotlib.pyplot as plt
+
+class ResultsFormatter(object):
+
+ @staticmethod
+ def get_results_df(results):
+ def change_column_names(group):
+ group.rename(columns=lambda name: re.sub('(.*_)(?=[a-zA-Z0-9]+$)', '', name), inplace=True)
+ return group
+ results_df = pd.DataFrame([results])
+ group_pattern = '(.*(?=_recall|_precision|_f1|_number))'
+ grouped = results_df.groupby(results_df.columns.str.extract(group_pattern, expand=False), axis=1)
+ grouped_df_dict = {name: change_column_names(group) for name, group in grouped}
+ grouped_df = pd.concat(grouped_df_dict.values(), axis=1, keys=grouped_df_dict.keys())
+ return grouped_df.T.unstack().droplevel(level=0, axis=1)[['precision', 'recall', 'f1', 'number']]
+
+ @staticmethod
+ def get_confusion_matrix(confusion_matrix, ner_types):
+ S = 15
+ title = 'Confusion Matrix'
+ cmap = plt.cm.Blues
+ classes = ner_types + ['O', ]
+
+ # Keep the raw counts for the cell annotations and row-normalize the
+ # matrix that drives the heatmap colours
+ cmbk = confusion_matrix
+ cm = confusion_matrix.astype('float') / confusion_matrix.sum(axis=1)[:, np.newaxis]
+
+ fig, ax = plt.subplots(figsize=(S, S * 0.8))
+ im = ax.imshow(cm, interpolation='nearest', cmap=cmap)
+ ax.figure.colorbar(im, ax=ax)
+ # We want to show all ticks...
+ ax.set(xticks=np.arange(0, cm.shape[1]),
+ yticks=np.arange(0, cm.shape[0]),
+ # ... and label them with the respective list entries
+ xticklabels=classes, yticklabels=classes,
+ title=title,
+ ylabel='Ground Truth',
+ xlabel='Predicted')
+ ax.xaxis.get_label().set_fontsize(16)
+ ax.yaxis.get_label().set_fontsize(16)
+ ax.title.set_size(16)
+ ax.tick_params(axis='both', which='major', labelsize=14)
+
+ # Rotate the tick labels and set their alignment.
+ plt.setp(ax.get_xticklabels(), rotation=45, ha="right",
+ rotation_mode="anchor")
+
+ # Loop over data dimensions and create text annotations.
+ fmt = 'd' # annotate cells with raw counts; the colours already show the normalized values
+ thresh = cm.max() / 2.
+ for i in range(cm.shape[0]): + for j in range(cm.shape[1]): + ax.text(j, i, format(cmbk[i, j], fmt), + ha="center", va="center", + color="white" if cm[i, j] > thresh else "black",fontsize=12) + fig.tight_layout() + return fig \ No newline at end of file diff --git a/sequence_tagging/models/__init__.py b/sequence_tagging/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/sequence_tagging/models/__pycache__/__init__.cpython-37.pyc b/sequence_tagging/models/__pycache__/__init__.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5ee67f0abfbcde505dd18c45de6b20495068d26c Binary files /dev/null and b/sequence_tagging/models/__pycache__/__init__.cpython-37.pyc differ diff --git a/sequence_tagging/models/hf/__init__.py b/sequence_tagging/models/hf/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f5bd38ebc372be9bcce6c454144accc3b43f5594 --- /dev/null +++ b/sequence_tagging/models/hf/__init__.py @@ -0,0 +1,2 @@ +from .model_picker import ModelPicker +__all__=["ModelPicker"] \ No newline at end of file diff --git a/sequence_tagging/models/hf/__pycache__/__init__.cpython-37.pyc b/sequence_tagging/models/hf/__pycache__/__init__.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..375d45adaa85a780e954991a5ff539f095d9571a Binary files /dev/null and b/sequence_tagging/models/hf/__pycache__/__init__.cpython-37.pyc differ diff --git a/sequence_tagging/models/hf/__pycache__/model_picker.cpython-37.pyc b/sequence_tagging/models/hf/__pycache__/model_picker.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7b8d3c646bdef7b0cde879859e82df913f311182 Binary files /dev/null and b/sequence_tagging/models/hf/__pycache__/model_picker.cpython-37.pyc differ diff --git a/sequence_tagging/models/hf/crf/__init__.py b/sequence_tagging/models/hf/crf/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5168850ab3712406b0af83eef9627a430301ba2e --- /dev/null +++ b/sequence_tagging/models/hf/crf/__init__.py @@ -0,0 +1,2 @@ +from .crf_bert_model_for_token_classification import CRFBertModelForTokenClassification +__all__=["CRFBertModelForTokenClassification"] \ No newline at end of file diff --git a/sequence_tagging/models/hf/crf/__pycache__/__init__.cpython-37.pyc b/sequence_tagging/models/hf/crf/__pycache__/__init__.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1096fa96c6c713ec0567f3ddae54b5ebf56442dc Binary files /dev/null and b/sequence_tagging/models/hf/crf/__pycache__/__init__.cpython-37.pyc differ diff --git a/sequence_tagging/models/hf/crf/__pycache__/conditional_random_field_sub.cpython-37.pyc b/sequence_tagging/models/hf/crf/__pycache__/conditional_random_field_sub.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..77573b03d5eac6f429bc6d0da8f84f54f9905dbc Binary files /dev/null and b/sequence_tagging/models/hf/crf/__pycache__/conditional_random_field_sub.cpython-37.pyc differ diff --git a/sequence_tagging/models/hf/crf/__pycache__/crf_bert_model_for_token_classification.cpython-37.pyc b/sequence_tagging/models/hf/crf/__pycache__/crf_bert_model_for_token_classification.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..99f608171e5fd773708755ac51ff3f4cd0b98184 Binary files /dev/null and 
b/sequence_tagging/models/hf/crf/__pycache__/crf_bert_model_for_token_classification.cpython-37.pyc differ diff --git a/sequence_tagging/models/hf/crf/conditional_random_field_sub.py b/sequence_tagging/models/hf/crf/conditional_random_field_sub.py new file mode 100644 index 0000000000000000000000000000000000000000..3b8d24f8c033248331a4a088c29cd5d89fb62ee8 --- /dev/null +++ b/sequence_tagging/models/hf/crf/conditional_random_field_sub.py @@ -0,0 +1,44 @@ +import torch
+from typing import List, Tuple, NoReturn
+
+from allennlp.modules import ConditionalRandomField
+
+
+class ConditionalRandomFieldSub(ConditionalRandomField):
+ """
+ Implement a CRF layer.
+ The code is borrowed from allennlp. We could have used the class directly, but doing so
+ threw an error saying the mask tensor could not be found on the GPU, so we subclass it
+ and put the mask tensor on the right device. Refer to allennlp for more details
+ """
+
+ def __init__(self, num_labels: int, constraints: List[Tuple[int, int]]) -> NoReturn:
+ """
+ Initialize the allennlp class with the number of labels and constraints
+ Args:
+ num_labels (int): The number of possible tags/labels (B-AGE, I-DATE, etc)
+ constraints (List[Tuple[int, int]]): Constraints on tag transitions - for example,
+ don't allow transitions from B-DATE to I-MRN etc
+ """
+ super().__init__(num_labels, constraints)
+
+ def forward(self, inputs: torch.Tensor, tags: torch.Tensor, mask: torch.BoolTensor = None) -> torch.Tensor:
+ """
+ Computes the log likelihood.
+ The only change we make is moving the mask tensor to the same device as the inputs
+ Args:
+ inputs (torch.Tensor): Model logits
+ tags (torch.Tensor): True labels
+ mask (torch.BoolTensor): Mask
+ """
+ if mask is None:
+ mask = torch.ones(*tags.size(), dtype=torch.bool, device=inputs.device)
+ else:
+ # The code below fails in weird ways if this isn't a bool tensor on the
+ # same device as the inputs, so we make sure it is
+ mask = mask.to(device=inputs.device, dtype=torch.bool)
+ # Compute the CRF loss
+ log_denominator = self._input_likelihood(inputs, mask)
+ log_numerator = self._joint_likelihood(inputs, tags, mask)
+ # Return the crf loss
+ return torch.sum(log_numerator - log_denominator) diff --git a/sequence_tagging/models/hf/crf/crf_bert_model_for_token_classification.py b/sequence_tagging/models/hf/crf/crf_bert_model_for_token_classification.py new file mode 100644 index 0000000000000000000000000000000000000000..376b0aa30b86ce7e9604bcd1dba410e16ebe3476 --- /dev/null +++ b/sequence_tagging/models/hf/crf/crf_bert_model_for_token_classification.py @@ -0,0 +1,85 @@ +from transformers import (
+ BertConfig,
+ BertForTokenClassification,
+)
+
+from .conditional_random_field_sub import ConditionalRandomFieldSub
+from .crf_token_classifier_output import CRFTokenClassifierOutput
+
+
+class CRFBertModelForTokenClassification(BertForTokenClassification):
+ def __init__(
+ self,
+ config: BertConfig,
+ crf_constraints
+ ):
+ super().__init__(config)
+ self.crf = ConditionalRandomFieldSub(num_labels=config.num_labels, constraints=crf_constraints)
+
+ def forward(
+ self,
+ input_ids=None,
+ attention_mask=None,
+ token_type_ids=None,
+ position_ids=None,
+ head_mask=None,
+ inputs_embeds=None,
+ labels=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ ):
+ r"""
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
+ Labels for computing the token classification loss.
Indices should be in ``[0, ..., config.num_labels -
+ 1]``.
+ """
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ # Alternatively we could use self.base_model - that might work with the auto model class
+ outputs = self.bert(
+ input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ sequence_output = outputs[0]
+ sequence_output = self.dropout(sequence_output)
+ logits = self.classifier(sequence_output)
+ batch_size = logits.shape[0]
+ sequence_length = logits.shape[1]
+ loss = None
+ if labels is not None:
+ # Negative of the log likelihood.
+ # Loop through the batch here for 2 reasons:
+ # 1- the CRF package assumes the mask tensor cannot have interleaved
+ # zeros and ones. In other words, the mask should start with True
+ # values, transition to False at some moment and never transition
+ # back to True. That can only happen for simple padded sequences.
+ # 2- The first column of the mask tensor should be all True, and we
+ # cannot guarantee that because we have to mask all non-first
+ # subtokens of the WordPiece tokenization.
+ loss = 0
+ for seq_logits, seq_labels in zip(logits, labels):
+ # Index the logits and labels using the prediction mask to pass only the
+ # first subtoken of each word to the CRF.
+ seq_mask = seq_labels != -100
+ seq_logits_crf = seq_logits[seq_mask].unsqueeze(0)
+ seq_labels_crf = seq_labels[seq_mask].unsqueeze(0)
+ loss -= self.crf(inputs=seq_logits_crf, tags=seq_labels_crf)
+ loss /= batch_size
+ if not return_dict:
+ output = (logits,) + outputs[2:]
+ return ((loss,) + output) if loss is not None else output
+ return CRFTokenClassifierOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ ) diff --git a/sequence_tagging/models/hf/crf/crf_token_classifier_output.py b/sequence_tagging/models/hf/crf/crf_token_classifier_output.py new file mode 100644 index 0000000000000000000000000000000000000000..490d472ced6d0f3bceb06e524bd4b3077eb011bd --- /dev/null +++ b/sequence_tagging/models/hf/crf/crf_token_classifier_output.py @@ -0,0 +1,16 @@ +import torch
+from dataclasses import dataclass
+from transformers.modeling_outputs import TokenClassifierOutput
+
+
+@dataclass
+class CRFTokenClassifierOutput(TokenClassifierOutput):
+ """
+ The default TokenClassifierOutput returns logits, loss, hidden_states and attentions.
+ When we use the CRF module, we want the model.forward function to also return the predicted
+ sequence from the CRF module. So we introduce this class, which subclasses TokenClassifierOutput
+ and additionally returns the predictions tensor - which contains the predicted tag sequences
+ for the input examples.
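# --- Editor's sketch (illustrative, not part of the patch). Putting the CRF
# pieces together: constraints come from allowed_transitions (defined in
# crf/utils.py below), and from_pretrained should forward the extra
# crf_constraints kwarg to __init__ (transformers passes unrecognised kwargs
# on to the model constructor). id_to_label is an assumed mapping like
# {0: 'O', 1: 'B-AGE', 2: 'I-AGE', ...}:
from transformers import AutoConfig

config = AutoConfig.from_pretrained('bert-base-cased', num_labels=len(id_to_label))
constraints = allowed_transitions('BIO', id_to_label)
model = CRFBertModelForTokenClassification.from_pretrained(
    'bert-base-cased', config=config, crf_constraints=constraints
)
# --- end sketch ---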
+ """ + predictions: torch.LongTensor = None + scores: torch.LongTensor = None diff --git a/sequence_tagging/models/hf/crf/utils.py b/sequence_tagging/models/hf/crf/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..3a018c3c1205c88404dcfac0f3c1c597cb46cb2b --- /dev/null +++ b/sequence_tagging/models/hf/crf/utils.py @@ -0,0 +1,144 @@ +# utils for the CRF class - code is borrowed from allenlp and we make a minor change +# refer to allennlp for further details +from typing import List, Tuple, Dict +from allennlp.common.checks import ConfigurationError + + +def allowed_transitions(constraint_type: str, labels: Dict[int, str]) -> List[Tuple[int, int]]: + """ + Given labels and a constraint type, returns the allowed transitions. It will + additionally include transitions for the start and end states, which are used + by the conditional random field. + # Parameters + constraint_type : `str`, required + Indicates which constraint to apply. Current choices are + "BIO", "IOB1", "BIOUL", and "BMES". + labels : `Dict[int, str]`, required + A mapping {label_id -> label}. Most commonly this would be the value from + Vocabulary.get_index_to_token_vocabulary() + # Returns + `List[Tuple[int, int]]` + The allowed transitions (from_label_id, to_label_id). + """ + num_labels = len(labels) + start_tag = num_labels + end_tag = num_labels + 1 + labels_with_boundaries = list(labels.items()) + [(start_tag, "START"), (end_tag, "END")] + + allowed = [] + for from_label_index, from_label in labels_with_boundaries: + if from_label in ("START", "END"): + from_tag = from_label + from_entity = "" + else: + from_tag = from_label[0] + from_entity = from_label[1:] + for to_label_index, to_label in labels_with_boundaries: + if to_label in ("START", "END"): + to_tag = to_label + to_entity = "" + else: + to_tag = to_label[0] + to_entity = to_label[1:] + if is_transition_allowed(constraint_type, from_tag, from_entity, to_tag, to_entity): + allowed.append((from_label_index, to_label_index)) + return allowed + + +def is_transition_allowed( + constraint_type: str, from_tag: str, from_entity: str, to_tag: str, to_entity: str +): + """ + Given a constraint type and strings `from_tag` and `to_tag` that + represent the origin and destination of the transition, return whether + the transition is allowed under the given constraint type. + # Parameters + constraint_type : `str`, required + Indicates which constraint to apply. Current choices are + "BIO", "IOB1", "BIOUL", and "BMES". + from_tag : `str`, required + The tag that the transition originates from. For example, if the + label is `I-PER`, the `from_tag` is `I`. + from_entity : `str`, required + The entity corresponding to the `from_tag`. For example, if the + label is `I-PER`, the `from_entity` is `PER`. + to_tag : `str`, required + The tag that the transition leads to. For example, if the + label is `I-PER`, the `to_tag` is `I`. + to_entity : `str`, required + The entity corresponding to the `to_tag`. For example, if the + label is `I-PER`, the `to_entity` is `PER`. + # Returns + `bool` + Whether the transition is allowed under the given `constraint_type`. 
+ """ + + if to_tag == "START" or from_tag == "END": + # Cannot transition into START or from END + return False + + if constraint_type == "BIOUL" or constraint_type == "BILOU": + if from_tag == "START": + # return to_tag in ("O", "B", "U") + return to_tag in ("O", "B", "U", "I", "L") + if to_tag == "END": + # return from_tag in ("O", "L", "U") + return from_tag in ("O", "B", "U", "I", "L") + return any( + [ + # O can transition to O, B-* or U-* + # L-x can transition to O, B-*, or U-* + # U-x can transition to O, B-*, or U-* + from_tag in ("O", "L", "U") and to_tag in ("O", "B", "U"), + # B-x can only transition to I-x or L-x + # I-x can only transition to I-x or L-x + from_tag in ("B", "I") and to_tag in ("I", "L") and from_entity == to_entity, + ] + ) + elif constraint_type == "BIO": + if from_tag == "START": + # return to_tag in ("O", "B") + return to_tag in ("O", "B", "I") + if to_tag == "END": + return from_tag in ("O", "B", "I") + return any( + [ + # Can always transition to O or B-x + to_tag in ("O", "B"), + # Can only transition to I-x from B-x or I-x + to_tag == "I" and from_tag in ("B", "I") and from_entity == to_entity, + ] + ) + elif constraint_type == "IOB1": + if from_tag == "START": + return to_tag in ("O", "I") + if to_tag == "END": + return from_tag in ("O", "B", "I") + return any( + [ + # Can always transition to O or I-x + to_tag in ("O", "I"), + # Can only transition to B-x from B-x or I-x, where + # x is the same tag. + to_tag == "B" and from_tag in ("B", "I") and from_entity == to_entity, + ] + ) + elif constraint_type == "BMES": + if from_tag == "START": + return to_tag in ("B", "S") + if to_tag == "END": + return from_tag in ("E", "S") + return any( + [ + # Can only transition to B or S from E or S. + to_tag in ("B", "S") and from_tag in ("E", "S"), + # Can only transition to M-x from B-x, where + # x is the same tag. + to_tag == "M" and from_tag in ("B", "M") and from_entity == to_entity, + # Can only transition to E-x from B-x or M-x, where + # x is the same tag. + to_tag == "E" and from_tag in ("B", "M") and from_entity == to_entity, + ] + ) + else: + raise ConfigurationError(f"Unknown constraint type: {constraint_type}") diff --git a/sequence_tagging/models/hf/model_picker.py b/sequence_tagging/models/hf/model_picker.py new file mode 100644 index 0000000000000000000000000000000000000000..635282df94dc8e61f0c7078638481848dca601c1 --- /dev/null +++ b/sequence_tagging/models/hf/model_picker.py @@ -0,0 +1,84 @@ +# Use the functions below to get the desired model for training. +from typing import Dict, NoReturn + +from transformers import AutoConfig, AutoModelForTokenClassification + +from .crf.utils import allowed_transitions +from .crf import CRFBertModelForTokenClassification + + +class ModelPicker(object): + """ + This class is used to pick the model we are using to train. 
+    The class provides functions that return the desired model objects,
+    i.e., get the desired model for training, etc.
+    """
+
+    def __init__(
+            self,
+            model_name_or_path: str,
+            config: AutoConfig,
+            cache_dir: str,
+            model_revision: str,
+            use_auth_token: bool
+    ) -> NoReturn:
+        """
+        Initialize the variables needed for loading the huggingface models
+        Args:
+            model_name_or_path (str): Path to pretrained model or model identifier from huggingface.co/models
+            config (AutoConfig): Pretrained config object
+            cache_dir (str): Where do you want to store the pretrained models downloaded from huggingface.co
+            model_revision (str): The specific model version to use (can be a branch name, tag name or commit id).
+            use_auth_token (bool): Will use the token generated when running `transformers-cli login`
+                                   (necessary to use this script with private models).
+        """
+        self._model_name_or_path = model_name_or_path
+        self._config = config
+        self._cache_dir = cache_dir
+        self._model_revision = model_revision
+        self._use_auth_token = use_auth_token
+
+    def get_argmax_bert_model(self) -> AutoModelForTokenClassification:
+        """
+        Return a model that uses argmax to process the model logits for obtaining the predictions
+        and calculating the loss
+        Returns:
+            (AutoModelForTokenClassification): Return the argmax token classification model
+        """
+        return AutoModelForTokenClassification.from_pretrained(
+            self._model_name_or_path,
+            from_tf=bool(".ckpt" in self._model_name_or_path),
+            config=self._config,
+            cache_dir=self._cache_dir,
+            revision=self._model_revision,
+            use_auth_token=self._use_auth_token,
+        )
+
+    def get_crf_bert_model(
+            self,
+            notation: str,
+            id_to_label: Dict[int, str]
+    ) -> CRFBertModelForTokenClassification:
+        """
+        Return a model that uses a CRF to process the model logits for obtaining the predictions
+        and calculating the loss. Set the CRF constraints based on the notation and the labels.
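# Hypothetical wiring for ModelPicker (the model name and label count are
# placeholders; the AutoConfig call mirrors __get_config further down this diff):
from transformers import AutoConfig

config = AutoConfig.from_pretrained("bert-base-uncased", num_labels=3)
picker = ModelPicker(
    model_name_or_path="bert-base-uncased",
    config=config,
    cache_dir=None,
    model_revision="main",
    use_auth_token=False,
)
model = picker.get_argmax_bert_model()  # plain token-classification head, no CRF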
+        For example, the sequence B-DATE, I-LOC is not valid, since we add the constraint that the I-LOC
+        label cannot follow the B-DATE label and that only the I-DATE label can follow the B-DATE label
+        Args:
+            notation (str): The NER notation - e.g BIO, BILOU
+            id_to_label (Mapping[int, str]): Mapping between the NER label ID and the NER label
+        Returns:
+            (CRFBertModelForTokenClassification): Return the CRF token classification model
+        """
+        constraints = {
+            'crf_constraints': allowed_transitions(constraint_type=notation, labels=id_to_label)
+        }
+        return CRFBertModelForTokenClassification.from_pretrained(
+            self._model_name_or_path,
+            from_tf=bool(".ckpt" in self._model_name_or_path),
+            config=self._config,
+            cache_dir=self._cache_dir,
+            revision=self._model_revision,
+            use_auth_token=self._use_auth_token,
+            **constraints,
+        )
diff --git a/sequence_tagging/note_aggregate/__init__.py b/sequence_tagging/note_aggregate/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..2940fa09e4c3634505c366ecbe4cccc77093cf19
--- /dev/null
+++ b/sequence_tagging/note_aggregate/__init__.py
@@ -0,0 +1,2 @@
+from .note_level_aggregator import NoteLevelAggregator
+__all__ = ["NoteLevelAggregator"]
\ No newline at end of file
diff --git a/sequence_tagging/note_aggregate/note_level_aggregator.py b/sequence_tagging/note_aggregate/note_level_aggregator.py
new file mode 100644
index 0000000000000000000000000000000000000000..215b150f87154089af762c9351c189daa7d3ed3e
--- /dev/null
+++ b/sequence_tagging/note_aggregate/note_level_aggregator.py
@@ -0,0 +1,128 @@
+from collections import defaultdict
+from typing import Sequence, Mapping, NoReturn, List, Union, DefaultDict
+
+
+class NoteLevelAggregator(object):
+    """
+    The input while training the model is at a sentence level. What happens is we
+    have a bunch of notes (say 20) which we split into sentences and tokenize, so
+    we end up with tokenized sentences (say 400). Each sentence is then used as a
+    training example. Now this list of sentences is shuffled and the model is trained.
+    For evaluation and prediction, however, we want to know which sentences belong to
+    which note, since we want to go back from the sentence level to the note level.
+    This class aggregates sentence level information back to the note level, so that we
+    can do evaluation at the note level and get aggregate predictions for the entire note.
+    To perform this we keep track of all the note_ids - [ID1, ID2 ...]. We use this list
+    as a reference - so when we return predictions we return a list [[P1], [P2] ..] where
+    P1 corresponds to the predictions for the note with note id ID1.
+    """
+
+    def __init__(
+            self,
+            note_ids: Sequence[str],
+            note_sent_info: Sequence[Mapping[str, Union[str, int]]]
+    ) -> NoReturn:
+        """
+        Initialize the reference note_ids; this list and the position of each note_id in it
+        are used as the reference when aggregating predictions/tokens belonging to a note.
+        The note_ids are used as references for the functions below.
+        Args:
+            note_ids (Sequence[str]): The sequence of note_ids to use as reference
+            note_sent_info (Sequence[Mapping[str, Union[str, int]]]): The information for each sentence
+                (training example): which note_id the sentence belongs to and the start and end
+                position of that sentence in the note
+        """
+        self._note_ids = note_ids
+        self._note_index_map = self.__get_note_index_map(note_sent_info)
+        check_len = len([index for note_index in self._note_index_map for index in note_index])
+        check_len_unique = len(set([index for note_index in self._note_index_map for index in note_index]))
+        if len(note_sent_info) != check_len or check_len != check_len_unique:
+            raise ValueError('Length mismatch between note_sent_info and the aggregated sentence indices')
+
+    @staticmethod
+    def __get_note_aggregate(
+            note_sent_info: Sequence[Mapping[str, Union[str, int]]]
+    ) -> DefaultDict[str, List[Mapping[str, int]]]:
+        """
+        Return a mapping where the key is the note_id and the value is a sequence that
+        contains the sentence information. For example 'id1': [{'index': 8, 'start': 0, 'end': 30},
+        {'index': 80, 'start': 35, 'end': 70}, {'index': 2, 'start': 71, 'end': 100}, ...]
+        This mapping says that, for this note_id, the first sentence in the note
+        is the 8th sentence in the dataset, the second sentence in the note is the 80th
+        sentence in the dataset and the third sentence is the 2nd sentence in the dataset.
+        This is because the dataset can be shuffled.
+        Args:
+            note_sent_info (Sequence[Mapping[str, Union[str, int]]]): The information for each sentence
+                (training example): which note_id the sentence belongs to and the start and end
+                position of that sentence in the note
+        Returns:
+            note_aggregate (DefaultDict[str, List[Mapping[str, int]]]): Contains the note_id to sentence
+                (training example) mapping, with each sentence's position within the dataset
+        """
+        note_aggregate = defaultdict(list)
+        for index, note_sent in enumerate(note_sent_info):
+            note_id = note_sent['note_id']
+            start = note_sent['start']
+            end = note_sent['end']
+            note_aggregate[note_id].append({'index': index, 'start': start, 'end': end})
+        # Sort the sentences/training examples based on their start position in the note
+        for note_id, aggregate_info in note_aggregate.items():
+            aggregate_info.sort(key=lambda info: info['start'])
+        return note_aggregate
+
+    def __get_note_index_map(self, note_sent_info: Sequence[Mapping[str, Union[str, int]]]) -> List[List[int]]:
+        """
+        Return a nested sequence where the inner sequence contains the sentence positions w.r.t.
+        the dataset for that note (the note being note_id_1 for position 1).
+        For example, say we have note_ids=[i1, i2, i3, ...].
+        This function will return [[8, 80, 2, ...], [7, 89, 9], [1, 3, 5, ...], ...]
+        Where position 1 corresponds to ID i1 and we say that the 8th, 80th and 2nd
+        sentences in the dataset correspond to the sentences in the note i1 (in sorted order).
+        For position 2, the ID is i2 and we say that the 7th, 89th and 9th sentences (training examples)
+        in the dataset correspond to the sentences in the note i2 (in sorted order).
+        Remember the dataset can be shuffled.
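# Toy illustration of the aggregation described above (ids and offsets are
# hypothetical). Sentence 2 starts note n1, so it comes first after sorting:
note_sent_info = [
    {'note_id': 'n2', 'start': 0, 'end': 12},    # dataset index 0
    {'note_id': 'n1', 'start': 30, 'end': 55},   # dataset index 1
    {'note_id': 'n1', 'start': 0, 'end': 29},    # dataset index 2
]
aggregator = NoteLevelAggregator(note_ids=['n1', 'n2'], note_sent_info=note_sent_info)
# Internal index map: [[2, 1], [0]] - note n1 owns dataset sentences 2 then 1,
# note n2 owns sentence 0.
aggregator.get_aggregate_sequences([['O'], ['B-DATE'], ['B-AGE', 'O']])
# -> [['B-AGE', 'O', 'B-DATE'], ['O']]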
+        Args:
+            note_sent_info (Sequence[Mapping[str, Union[str, int]]]): The information for each sentence
+                (training example): which note_id the sentence belongs to and the start and end
+                position of that sentence in the note
+        Returns:
+            List[List[int]]: A nested sequence where the inner sequence contains the sentence
+                positions w.r.t. the dataset for that note (the note being note_id_1 for position 1)
+        """
+        note_aggregate = NoteLevelAggregator.__get_note_aggregate(note_sent_info)
+        # Index directly so that an unknown note_id raises a clear KeyError here
+        # instead of an opaque TypeError from iterating over None
+        return [[note_agg['index'] for note_agg in note_aggregate[note_id]] for note_id in self._note_ids]
+
+    def get_aggregate_sequences(
+            self,
+            sequences: Union[Sequence[Sequence[str]], Sequence[Sequence[Mapping[str, Union[str, int]]]]]
+    ) -> List[List[str]]:
+        """
+        Return a nested sequence where the inner sequence contains the tokens or labels
+        for that note (the note being note_id_1 for position 1).
+        For example, say we have note_ids=[i1, i2, i3, ...].
+        This function will return [[PREDICTIONS - i1 ...], [PREDICTIONS - i2 ...], [PREDICTIONS - i3 ...], ...]
+        Where position 1 corresponds to ID i1 and it contains the predictions present in the note i1
+        (in sorted order), position 2 corresponds to ID i2 and it contains the predictions present
+        in the note i2 (in sorted order), and so on.
+        Remember the dataset can be shuffled.
+        Args:
+            sequences (Union[Sequence[Sequence[str]], Sequence[Sequence[Mapping[str, Union[str, int]]]]]):
+                The sequence of tokens or labels
+        Returns:
+            List[List[str]]: A nested sequence where the inner sequence contains the predictions
+                for that note (the note being note_id_1 for position 1)
+        """
+        return [[sequence for index in note_index for sequence in sequences[index]] for note_index in
+                self._note_index_map]
diff --git a/sequence_tagging/post_process/__init__.py b/sequence_tagging/post_process/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/sequence_tagging/post_process/model_outputs/__init__.py b/sequence_tagging/post_process/model_outputs/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..bbaf5de7fb5f94c2c5bd74acbb4a22410f8db1ae
--- /dev/null
+++ b/sequence_tagging/post_process/model_outputs/__init__.py
@@ -0,0 +1,2 @@
+from .post_process_picker import PostProcessPicker
+__all__ = ["PostProcessPicker"]
\ No newline at end of file
diff --git a/sequence_tagging/post_process/model_outputs/argmax_process.py b/sequence_tagging/post_process/model_outputs/argmax_process.py
new file mode 100644
index 0000000000000000000000000000000000000000..7e71812cd446979a81706b51d273ae64dcd79457
--- /dev/null
+++ b/sequence_tagging/post_process/model_outputs/argmax_process.py
@@ -0,0 +1,53 @@
+import numpy as np
+from typing import Sequence, NoReturn, Tuple
+
+from .utils import check_consistent_length
+
+
+class ArgmaxProcess(object):
+    """
+    Process the output of the model forward pass. The forward pass will return the predictions
+    (e.g. the logits) and the labels, if present. We process the output and return the processed
+    values based on the application. In this script we return the final prediction as the
+    argmax of the logits.
+    """
+
+    def __init__(self, label_list: Sequence[str]) -> NoReturn:
+        """
+        Initialize a label list where the position corresponds to a particular label. For example
+        position 0 will correspond to B-DATE etc.
+        Args:
+            label_list (Sequence[str]): The list of NER labels
+        """
+        self._label_list = label_list
+
+    def decode(
+            self,
+            predictions: Sequence[Sequence[Sequence[float]]],
+            labels: Sequence[Sequence[int]]
+    ) -> Tuple[Sequence[Sequence[str]], Sequence[Sequence[str]]]:
+        """
+        Decode the predictions and labels so that the evaluation function and prediction
+        functions can use them accordingly. The predictions and labels are numbers (ids)
+        of the labels; these will be converted back to the NER tags (B-AGE, I-DATE etc.) using
+        the label_list.
In this function we just take the argmax of the logits (scores) of the predictions + Also remove the predictions and labels on the subword and context tokens + Args: + predictions (Sequence[Sequence[Sequence[float]]]): The logits (scores for each tag) returned by the model + labels (Sequence[Sequence[str]]): Gold standard labels + Returns: + true_predictions (Sequence[Sequence[str]]): The predicted NER tags + true_labels (Sequence[Sequence[str]]): The gold standard NER tags + """ + predictions = np.argmax(predictions, axis=2) + # Remove ignored index (special tokens) + true_predictions = [ + [self._label_list[p] for (p, l) in zip(prediction, label) if l != -100] + for prediction, label in zip(predictions, labels) + ] + true_labels = [ + [self._label_list[l] for (p, l) in zip(prediction, label) if l != -100] + for prediction, label in zip(predictions, labels) + ] + check_consistent_length(true_predictions, true_labels) + return true_predictions, true_labels diff --git a/sequence_tagging/post_process/model_outputs/crf_argmax_process.py b/sequence_tagging/post_process/model_outputs/crf_argmax_process.py new file mode 100644 index 0000000000000000000000000000000000000000..d6526b21ec6e6cb1184ae77c9841daec1dade7d8 --- /dev/null +++ b/sequence_tagging/post_process/model_outputs/crf_argmax_process.py @@ -0,0 +1,27 @@ +from typing import Sequence, NoReturn, List + +from .crf_process import CRFProcess + + +class CRFArgmaxProcess(CRFProcess): + + def __init__(self, label_list: Sequence[str], top_k: int = None) -> NoReturn: + """ + Initialize a label list where the position corresponds to a particular label. For example + position 0 will correspond to B-DATE etc. top k will return the top k CRF sequences + Args: + label_list (Sequence[str]): The list of NER labels + top_k (int): The number of top CRF sequences to return + """ + super().__init__(label_list, top_k) + + def process_sequences(self, sequences: Sequence[Sequence[str]], scores: Sequence[float]) -> List[str]: + """ + The function will get the top sequence given by the crf layer based on the CRF loss/score. + Args: + sequences (Sequence[Sequence[str]]): The list of possible sequences from the CRF layer + scores (Sequence[float]): The scores for the sequences + Returns: + (List[str]): Highest scoring sequence of tags + """ + return sequences[0] diff --git a/sequence_tagging/post_process/model_outputs/crf_process.py b/sequence_tagging/post_process/model_outputs/crf_process.py new file mode 100644 index 0000000000000000000000000000000000000000..d0158af1492430bfe9760ed6c27f0d7b456ae383 --- /dev/null +++ b/sequence_tagging/post_process/model_outputs/crf_process.py @@ -0,0 +1,99 @@ +from typing import Sequence, NoReturn + +import torch + +from .utils import check_consistent_length + + +class CRFProcess(object): + + def __init__( + self, + label_list: Sequence[str], + top_k: int + ) -> NoReturn: + """ + Initialize a label list where the position corresponds to a particular label. For example + position 0 will correspond to B-DATE etc. 
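# Tiny walk-through of ArgmaxProcess.decode above (labels and logits are made up):
label_list = ["B-AGE", "O"]
processor = ArgmaxProcess(label_list)
logits = [[[0.1, 2.0], [3.0, 0.2], [0.0, 1.0]]]  # one sentence, three sub-tokens
labels = [[1, -100, 0]]                          # -100 = sub-token/context, dropped
preds, golds = processor.decode(logits, labels)
# preds == [["O", "O"]]      (argmax at the two kept positions)
# golds == [["O", "B-AGE"]]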
+        The top_k value controls how many of the highest scoring CRF sequences are returned.
+        Args:
+            label_list (Sequence[str]): The list of NER labels
+            top_k (int): The number of top CRF sequences to return
+        """
+        self._label_list = label_list
+        self._top_k = top_k
+        self._crf = None
+
+    def set_crf(self, crf):
+        """
+        Store the CRF layer used while training the model
+        Args:
+            crf: The CRF layer - this contains the CRF weights (NER transition weights)
+        """
+        self._crf = crf
+
+    def process_sequences(
+            self,
+            sequences: Sequence[Sequence[str]],
+            scores: Sequence[float]
+    ) -> NoReturn:
+        """
+        This function is implemented by the subclass and returns a sequence of NER
+        predictions chosen from the candidate CRF sequences
+        Args:
+            sequences (Sequence[Sequence[str]]): The list of possible sequences from the CRF layer
+            scores (Sequence[float]): The scores for the sequences
+        """
+        raise NotImplementedError
+
+    def decode(
+            self,
+            predictions: Sequence[Sequence[Sequence[float]]],
+            labels: Sequence[Sequence[int]]
+    ):
+        """
+        Decode the predictions and labels so that the evaluation function and prediction
+        functions can use them accordingly. The predictions and labels are numbers (ids)
+        of the labels; these will be converted back to the NER tags (B-AGE, I-DATE etc.) using
+        the label_list. In this function we process the CRF sequences and their scores and
+        select the NER sequence based on the implementation of the process_sequences function.
+        Also remove the predictions and labels on the subword and context tokens
+        Args:
+            predictions (Sequence[Sequence[Sequence[float]]]): The logits (scores for each tag) returned by the model
+            labels (Sequence[Sequence[int]]): Gold standard labels
+        Returns:
+            true_predictions (Sequence[Sequence[str]]): The predicted NER tags
+            true_labels (Sequence[Sequence[str]]): The gold standard NER tags
+        """
+        # Check if the CRF layer has been initialized
+        if self._crf is None:
+            raise ValueError('CRF layer not initialized/set - use the set_crf function to set it')
+        # Convert to a torch tensor, since the CRF layer expects a torch tensor
+        logits = torch.tensor(predictions)
+        labels_tensor = torch.tensor(labels)
+        output_tags = list()
+        # Get the candidate CRF decodings for each sequence, then select and store
+        # one processed sequence via the process_sequences function.
+        # viterbi_tags returns a list of (tags, score) pairs when top_k is None,
+        # and a list of lists of (tags, score) pairs when top_k is set.
+        for seq_logits, seq_labels in zip(logits, labels_tensor):
+            seq_mask = seq_labels != -100
+            seq_logits_crf = seq_logits[seq_mask].unsqueeze(0)
+            tags = self._crf.viterbi_tags(seq_logits_crf, top_k=self._top_k)
+            # Unpack "batch" results
+            if self._top_k is None:
+                sequences = [tag[0] for tag in tags]
+                scores = [tag[1] for tag in tags]
+            else:
+                sequences = [tag[0] for tag in tags[0]]
+                scores = [tag[1] for tag in tags[0]]
+            output_tags.append(self.process_sequences(sequences, scores))
+        # Remove ignored index (special tokens)
+        true_predictions = [
+            [self._label_list[p] for p in prediction]
+            for prediction in output_tags
+        ]
+        true_labels = [
+            [self._label_list[l] for l in label if l != -100]
+            for label in labels
+        ]
+        check_consistent_length(true_predictions, true_labels)
+        return true_predictions, true_labels
diff --git a/sequence_tagging/post_process/model_outputs/logits_process.py b/sequence_tagging/post_process/model_outputs/logits_process.py
new file mode 100644
index 0000000000000000000000000000000000000000..3057e5fb68cdd77c75f331eac0ba467c8384ac08
--- /dev/null
+++ b/sequence_tagging/post_process/model_outputs/logits_process.py
@@ -0,0 +1,48 @@
+from typing import Sequence, NoReturn, Tuple
+
+
+class LogitsProcess(object):
+
""" + Process the output of the model forward pass. The forward pass will return the predictions + (e.g the logits), labels if present. We process the output and return the processed + values based on the application. This script we return the final prediction as the + argmax of the logits. + """ + + def __init__(self, label_list: Sequence[str]) -> NoReturn: + """ + Initialize a label list where the position corresponds to a particular label. For example + position 0 will correspond to B-DATE etc. + Args: + label_list (Sequence[str]): The list of NER labels + """ + self._label_list = label_list + + def decode( + self, + predictions: Sequence[Sequence[Sequence[float]]], + labels: Sequence[Sequence[int]] + ) -> Tuple[Sequence[Sequence[Sequence[float]]], Sequence[Sequence[str]]]: + """ + Decode the predictions and labels so that the evaluation function and prediction + functions can use them accordingly. The predictions and labels are numbers (ids) + of the labels, these will be converted back to the NER tags (B-AGE, I-DATE etc) using + the label_list. In this function we just take the argmax of the logits (scores) of the predictions + Also remove the predictions and labels on the subword and context tokens + Args: + predictions (Sequence[Sequence[Sequence[float]]]): The logits (scores for each tag) returned by the model + labels (Sequence[Sequence[int]]): Gold standard labels + Returns: + true_predictions (Sequence[Sequence[str]]): The predicted NER tags + true_labels (Sequence[Sequence[str]]): The gold standard NER tags + """ + # Remove ignored index (special tokens) + true_predictions = [ + [[float(value) for value in p] for (p, l) in zip(prediction, label) if l != -100] + for prediction, label in zip(predictions, labels) + ] + true_labels = [ + [self._label_list[l] for (p, l) in zip(prediction, label) if l != -100] + for prediction, label in zip(predictions, labels) + ] + return true_predictions, true_labels diff --git a/sequence_tagging/post_process/model_outputs/post_process_picker.py b/sequence_tagging/post_process/model_outputs/post_process_picker.py new file mode 100644 index 0000000000000000000000000000000000000000..b6f3f482ccfd63739980aa7cc160e17be118ab5e --- /dev/null +++ b/sequence_tagging/post_process/model_outputs/post_process_picker.py @@ -0,0 +1,57 @@ +from typing import Sequence + +from .argmax_process import ArgmaxProcess +from .crf_argmax_process import CRFArgmaxProcess +from .logits_process import LogitsProcess +from .threshold_process_max import ThresholdProcessMax + + +class PostProcessPicker(object): + """ + This class is used to pick the post process layer that processed the model + logits. The class provides functions that returns the desired post processor objects + For example we can pick the argamx of the logits, where we just choose the highest scoring + tag as the prediction for a token or we can use a crf layer to pick the highest + scoring sequence of tags + """ + def __init__(self, label_list): + """ + Initialize the NER label list + Args: + label_list (Sequence[str]): The NER labels. e.g B-DATE, I-DATE, B-AGE ... + """ + self._label_list = label_list + + def get_argmax(self) -> ArgmaxProcess: + """ + Return a post processor that uses argmax to process the model logits for obtaining the predictions + Chooses the highest scoring tag. 
+        Returns:
+            (ArgmaxProcess): Return the argmax post processor
+        """
+        return ArgmaxProcess(self._label_list)
+
+    def get_crf(self) -> CRFArgmaxProcess:
+        """
+        Return a post processor that uses a CRF layer to process the model logits for obtaining the predictions.
+        Chooses the highest scoring sequence of tags based on the CRF layer
+        Returns:
+            (CRFArgmaxProcess): Return the CRF layer post processor
+        """
+        return CRFArgmaxProcess(self._label_list)
+
+    def get_logits(self) -> LogitsProcess:
+        """
+        Return a post processor that returns the model logits
+        Returns:
+            (LogitsProcess): Return the logits post processor
+        """
+        return LogitsProcess(self._label_list)
+
+    def get_threshold_max(self, threshold) -> ThresholdProcessMax:
+        """
+        Return a post processor that predicts the highest scoring non-O tag only when its
+        softmax probability clears the given threshold, and the O tag otherwise
+        Args:
+            threshold (float): The minimum softmax probability a non-O tag needs to be predicted
+        Returns:
+            (ThresholdProcessMax): Return the threshold (max) post processor
+        """
+        return ThresholdProcessMax(self._label_list, threshold=threshold)
diff --git a/sequence_tagging/post_process/model_outputs/threshold_process_max.py b/sequence_tagging/post_process/model_outputs/threshold_process_max.py
new file mode 100644
index 0000000000000000000000000000000000000000..8fc7af4d440559605c6efbdd95bac3c499de31ca
--- /dev/null
+++ b/sequence_tagging/post_process/model_outputs/threshold_process_max.py
@@ -0,0 +1,61 @@
+from typing import Sequence, NoReturn, Tuple
+
+import numpy as np
+from scipy.special import softmax
+
+from .utils import check_consistent_length
+
+
+class ThresholdProcessMax(object):
+    """
+    Process the output of the model forward pass. A token is assigned the highest scoring
+    non-O tag only when that tag's softmax probability is greater than or equal to the
+    threshold; otherwise the token is assigned the O tag.
+    """
+
+    def __init__(self, label_list: Sequence[str], threshold: float) -> NoReturn:
+        """
+        Initialize a label list where the position corresponds to a particular label. For example
+        position 0 will correspond to B-DATE etc.
+        Args:
+            label_list (Sequence[str]): The list of NER labels
+            threshold (float): The minimum softmax probability a non-O tag needs to be predicted
+        """
+        self._label_list = label_list
+        self._threshold = threshold
+        self._outside_label_index = self._label_list.index('O')
+        self._mask = np.zeros((len(self._label_list)), dtype=bool)
+        self._mask[self._outside_label_index] = True
+
+    def get_masked_array(self, data):
+        # Mask out the O entry so that max/argmax run over the non-O scores only
+        return np.ma.MaskedArray(data=data, mask=self._mask)
+
+    def process_prediction(self, prediction):
+        softmax_prob = softmax(prediction)
+        masked_softmax_prob = self.get_masked_array(data=softmax_prob)
+        max_value = masked_softmax_prob[masked_softmax_prob >= self._threshold].max()
+        if isinstance(max_value, np.ma.core.MaskedConstant):
+            # No non-O probability cleared the threshold - fall back to O
+            return self._outside_label_index
+        else:
+            return masked_softmax_prob.argmax()
+
+    def decode(
+            self,
+            predictions: Sequence[Sequence[Sequence[float]]],
+            labels: Sequence[Sequence[int]]
+    ) -> Tuple[Sequence[Sequence[str]], Sequence[Sequence[str]]]:
+        """
+        Threshold each token's softmax scores (see process_prediction) and remove the
+        predictions and labels on the subword and context tokens
+        Args:
+            predictions (Sequence[Sequence[Sequence[float]]]): The logits (scores for each tag) returned by the model
+            labels (Sequence[Sequence[int]]): Gold standard labels
+        Returns:
+            true_predictions (Sequence[Sequence[str]]): The predicted NER tags
+            true_labels (Sequence[Sequence[str]]): The gold standard NER tags
+        """
+        # Remove ignored index (special tokens)
+        true_predictions = [
+            [self._label_list[self.process_prediction(p)] for (p, l) in zip(prediction, label) if l != -100]
+            for prediction, label in zip(predictions, labels)
+        ]
+        true_labels = [
+            [self._label_list[l] for (p, l) in zip(prediction, label) if l != -100]
+            for prediction, label in zip(predictions, labels)
+        ]
+        check_consistent_length(true_predictions, true_labels)
+        return true_predictions, true_labels
diff --git a/sequence_tagging/post_process/model_outputs/utils.py
b/sequence_tagging/post_process/model_outputs/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..61b6ec03b04f1910b950885b9ee6b5eeb5a88f21
--- /dev/null
+++ b/sequence_tagging/post_process/model_outputs/utils.py
@@ -0,0 +1,18 @@
+from typing import List
+
+
+def check_consistent_length(y_true: List[List[str]], y_pred: List[List[str]]):
+    """
+    Check that the two nested lists have consistent first and second dimensions,
+    i.e. the same number of sequences and the same length for each pair of sequences.
+    Args:
+        y_true : 2d array.
+        y_pred : 2d array.
+    """
+    len_true = list(map(len, y_true))
+    len_pred = list(map(len, y_pred))
+
+    if len(y_true) != len(y_pred) or len_true != len_pred:
+        message = 'Found input variables with inconsistent numbers of samples:\n{}\n{}'.format(len_true, len_pred)
+        raise ValueError(message)
diff --git a/sequence_tagging/sequence_tagger.py b/sequence_tagging/sequence_tagger.py
new file mode 100644
index 0000000000000000000000000000000000000000..3febd82b246a29df7697fd350000d12ab0e4bfda
--- /dev/null
+++ b/sequence_tagging/sequence_tagger.py
@@ -0,0 +1,568 @@
+# Train a model using the huggingface library.
+# The datasets have been built using the scripts in the ner_datasets folder;
+# these datasets are used as input to the model.
+import os
+import sys
+import json
+import logging
+from typing import Optional, Sequence
+
+import datasets
+import transformers
+from datasets import load_dataset, load_metric
+from transformers.trainer_utils import get_last_checkpoint, is_main_process
+from transformers import (
+    AutoConfig,
+    AutoModelForTokenClassification,
+    AutoTokenizer,
+    DataCollatorForTokenClassification,
+    HfArgumentParser,
+    PreTrainedTokenizerFast,
+    Trainer,
+    TrainingArguments,
+    set_seed,
+)
+from transformers.utils import check_min_version
+from transformers.utils.versions import require_version
+
+# Will error if the minimal version of Transformers is not installed. Remove at your own risk.
+#check_min_version("4.13.0.dev0") + +# require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/token-classification/requirements.txt") + +from .models.hf import ModelPicker +from .evaluation import MetricsCompute +from .note_aggregate import NoteLevelAggregator +from .post_process.model_outputs import PostProcessPicker +from .dataset_builder import DatasetTokenizer, LabelMapper, NERDataset, NERLabels +from .arguments import ModelArguments, DataTrainingArguments, EvaluationArguments + + +class SequenceTagger(object): + + def __init__( + self, + task_name, + notation, + ner_types, + model_name_or_path, + config_name: Optional[str] = None, + tokenizer_name: Optional[str] = None, + post_process: str = 'argmax', + cache_dir: Optional[str] = None, + model_revision: str = 'main', + use_auth_token: bool = False, + threshold: Optional[float] = None, + do_lower_case = False, + fp16: bool = False, + seed: int = 41, + local_rank: int = - 1 + ): + self._task_name = task_name + self._notation = notation + self._ner_types = ner_types + self._model_name_or_path = model_name_or_path + self._config_name = config_name if config_name else self._model_name_or_path + self._tokenizer_name = tokenizer_name if tokenizer_name else self._model_name_or_path + self._post_process = post_process + self._cache_dir = cache_dir + self._model_revision = model_revision + self._use_auth_token = use_auth_token + ner_labels = NERLabels(notation=self._notation, ner_types=self._ner_types) + self._label_list = ner_labels.get_label_list() + self._label_to_id = ner_labels.get_label_to_id() + self._id_to_label = ner_labels.get_id_to_label() + self._config = self.__get_config() + self._tokenizer = self.__get_tokenizer(do_lower_case=do_lower_case) + self._model, self._post_processor = self.__get_model(threshold=threshold) + self._dataset_tokenizer = None + # Data collator + self._data_collator = DataCollatorForTokenClassification( + self._tokenizer, + pad_to_multiple_of=8 if fp16 else None + ) + self._metrics_compute = None + self._train_dataset = None + self._eval_dataset = None + self._test_dataset = None + self._trainer = None + # Setup logging + self._logger = logging.getLogger(__name__) + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + handlers=[logging.StreamHandler(sys.stdout)], + ) + log_level = logging.INFO if is_main_process(local_rank) else logging.WARN + self._logger.setLevel(log_level) + # Set the verbosity to info of the Transformers logger (on main process only): + datasets.utils.logging.set_verbosity(log_level) + transformers.utils.logging.set_verbosity(log_level) + transformers.utils.logging.enable_default_handler() + transformers.utils.logging.enable_explicit_format() + # Set seed before initializing model. + self._seed = seed + set_seed(self._seed) + + def load( + self, + text_column_name: str = 'tokens', + label_column_name: str = 'labels', + pad_to_max_length: bool = False, + truncation: bool = True, + max_length: int = 512, + is_split_into_words: bool = True, + label_all_tokens: bool = False, + token_ignore_label: str = 'NA' + ): + # This following two lines of code is the one that is used to read the input dataset + # Run the subword tokenization on the pre-split tokens and then + # as mentioned above align the subtokens and labels and add the ignore + # label. 
This will read the input - say [60, year, old, in, 2080] + # and will return the subtokens - [60, year, old, in, 208, ##0] + # some other information like token_type_ids etc + # and the labels [0, 20, 20, 20, 3, -100] (0 - corresponds to B-AGE, 20 corresponds to O + # and 3 corresponds to B-DATE. This returned input serves as input for training the model + # or for gathering predictions from a trained model. + # Another important thing to note is that we have mentioned before that + # we add chunks of tokens that appear before and after the current chunk for context. We would + # also need to assign the label -100 (ignore_label) to these chunks, since we are using them + # only to provide context. For example the input would be something + # like tokens: [James, Doe, 60, year, old, in, 2080, BWH, tomorrow, only], + # labels: [NA, NA, B-DATE, O, O, O, B-DATE, NA, NA, NA]. NA represents the tokens used for context + # This function would return some tokenizer info (e.g attention mask etc), along with + # the information that maps the tokens to the subtokens - + # [James, Doe, 60, year, old, in, 208, ##0, BW, ##h, tomorrow, only] + # and the labels - [-100, -100, 0, 20, 20, 20, 3, -100, -100, -100] + # (if label_all_tokens was true, we would return [-100, -100, 0, 20, 20, 20, 3, 3, -100, -100]). + # Create an object that has the tokenize_and_align_labels function to perform + # the operation described above + # Map that sends B-Xxx label to its I-Xxx counterpart + b_to_i_label = [] + if label_all_tokens: + if self._notation != 'BIO': + raise ValueError('Label all tokens works only with BIO notation!') + b_to_i_label = [] + for idx, label in enumerate(self._label_list): + if label.startswith("B-") and label.replace("B-", "I-") in self._label_list: + b_to_i_label.append(self._label_list.index(label.replace("B-", "I-"))) + else: + b_to_i_label.append(idx) + # Padding strategy + padding = "max_length" if pad_to_max_length else False + self._dataset_tokenizer = DatasetTokenizer( + tokenizer=self._tokenizer, + token_column=text_column_name, + label_column=label_column_name, + label_to_id=self._label_to_id, + b_to_i_label=b_to_i_label, + padding=padding, + truncation=truncation, + max_length=max_length, + is_split_into_words=is_split_into_words, + label_all_tokens=label_all_tokens, + token_ignore_label=token_ignore_label + ) + + def set_train( + self, + train_file: str, + max_train_samples: Optional[int] = None, + preprocessing_num_workers: Optional[int] = None, + overwrite_cache: bool = False, + file_extension: str = 'json', + shuffle: bool = True, + ): + if shuffle: + train_dataset = load_dataset( + file_extension, + data_files={'train':train_file}, + cache_dir=self._cache_dir + ).shuffle(seed=self._seed) + else: + train_dataset = load_dataset( + file_extension, + data_files={'train':train_file}, + cache_dir=self._cache_dir + ) + train_dataset = train_dataset['train'] + # Run the tokenizer (subword), tokenize and align the labels as mentioned above on + # every example (row) of the dataset - (map function). 
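# Worked example of the b_to_i_label map built above (labels are hypothetical):
label_list = ['B-AGE', 'I-AGE', 'O']
b_to_i_label = [
    label_list.index(label.replace('B-', 'I-'))
    if label.startswith('B-') and label.replace('B-', 'I-') in label_list
    else idx
    for idx, label in enumerate(label_list)
]
# b_to_i_label == [1, 1, 2]: with label_all_tokens=True, the non-first sub-tokens
# of a word that starts a span can be tagged I-AGE rather than repeating B-AGE.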
This tokenized_datasets will be the + # input to the model (either for training or predictions + if max_train_samples is not None: + train_dataset = train_dataset.select(range(max_train_samples)) + self._train_dataset = train_dataset.map( + self._dataset_tokenizer.tokenize_and_align_labels, + batched=True, + num_proc=preprocessing_num_workers, + load_from_cache_file=not overwrite_cache, + ) + + def set_eval( + self, + validation_file: str, + max_val_samples: Optional[int] = None, + preprocessing_num_workers: Optional[int] = None, + overwrite_cache: bool = False, + file_extension: str = 'json', + shuffle: bool = True, + ): + if shuffle: + eval_dataset = load_dataset( + file_extension, + data_files={'eval':validation_file}, + cache_dir=self._cache_dir + ).shuffle(seed=self._seed) + else: + eval_dataset = load_dataset( + file_extension, + data_files={'eval':validation_file}, + cache_dir=self._cache_dir + ) + eval_dataset = eval_dataset['eval'] + # Eval + if max_val_samples is not None: + eval_dataset = eval_dataset.select(range(max_val_samples)) + self._eval_dataset = eval_dataset.map( + self._dataset_tokenizer.tokenize_and_align_labels, + batched=True, + num_proc=preprocessing_num_workers, + load_from_cache_file=not overwrite_cache, + ) + + def set_predict( + self, + test_file: str, + max_test_samples: Optional[int] = None, + preprocessing_num_workers: Optional[int] = None, + overwrite_cache: bool = False, + file_extension: str = 'json', + shuffle: bool = False, + ): + if shuffle: + test_dataset = load_dataset( + file_extension, + data_files={'test':test_file}, + cache_dir=self._cache_dir + ).shuffle(seed=self._seed) + else: + test_dataset = load_dataset( + file_extension, + data_files={'test':test_file}, + cache_dir=self._cache_dir + ) + test_dataset = test_dataset['test'] + # Eval + if max_test_samples is not None: + test_dataset = test_dataset.select(range(max_test_samples)) + self._test_dataset = test_dataset.map( + self._dataset_tokenizer.tokenize_and_align_labels, + batched=True, + num_proc=preprocessing_num_workers, + load_from_cache_file=not overwrite_cache, + ) + + def set_eval_metrics( + self, + validation_spans_file: str, + model_eval_script: str = './evaluation/note_evaluation.py', + ner_types_maps: Optional[Sequence[Sequence[str]]] = None, + evaluation_mode: str = 'strict' + ): + + if self._eval_dataset is None: + raise ValueError("Validation data not present") + + validation_ids = [json.loads(line)['note_id'] for line in open(validation_spans_file, 'r')] + validation_spans = [json.loads(line)['note_spans'] for line in open(validation_spans_file, 'r')] + descriptions = [''] + type_maps = [self._ner_types] + if ner_types_maps is not None: + descriptions += [''.join(list(set(ner_types_map) - set('O'))) for ner_types_map in ner_types_maps] + type_maps += ner_types_maps + label_mapper_list = [LabelMapper( + notation=self._notation, + ner_types=self._ner_types, + ner_types_maps=ner_types_map, + description=description + ) for ner_types_map, description in zip(type_maps, descriptions)] + # Use this to aggregate sentences back to notes for validation + note_level_aggregator = NoteLevelAggregator( + note_ids=validation_ids, + note_sent_info=self._eval_dataset['note_sent_info'] + ) + note_tokens = note_level_aggregator.get_aggregate_sequences( + sequences=self._eval_dataset['current_sent_info'] + ) + self._metrics_compute = MetricsCompute( + metric=load_metric(model_eval_script), + note_tokens=note_tokens, + note_spans=validation_spans, + label_mapper_list=label_mapper_list, + 
post_processor=self._post_processor, + note_level_aggregator=note_level_aggregator, + notation=self._notation, + mode=evaluation_mode, + confusion_matrix=False, + format_results=True + ) + + def setup_trainer(self, training_args): + # Log on each process the small summary: + self._logger.warning( + f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}" + + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}" + ) + self._logger.info(f"Training/evaluation parameters {training_args}") + # Detecting last checkpoint. + last_checkpoint = None + if os.path.isdir( + training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir: + last_checkpoint = get_last_checkpoint(training_args.output_dir) + if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0: + raise ValueError( + f"Output directory ({training_args.output_dir}) already exists and is not empty. " + "Use --overwrite_output_dir to overcome." + ) + elif last_checkpoint is not None and training_args.resume_from_checkpoint is None: + self._logger.info( + f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change " + "the `--output_dir` or add `--overwrite_output_dir` to train from scratch." + ) + # Initialize our Trainer + self._trainer = Trainer( + model=self._model, + args=training_args, + train_dataset=self._train_dataset, + eval_dataset=self._eval_dataset, + tokenizer=self._tokenizer, + data_collator=self._data_collator, + compute_metrics=self._metrics_compute, + ) + checkpoint = None + if training_args.resume_from_checkpoint is not None: + checkpoint = training_args.resume_from_checkpoint + elif last_checkpoint is not None: + checkpoint = last_checkpoint + return checkpoint + + def train(self, checkpoint: None): + if self._train_dataset is not None and self._trainer is not None: + train_result = self._trainer.train(resume_from_checkpoint=checkpoint) + metrics = train_result.metrics + self._trainer.save_model() # Saves the tokenizer too for easy upload + metrics["train_samples"] = len(self._train_dataset) + + self._trainer.log_metrics("train", metrics) + self._trainer.save_metrics("train", metrics) + self._trainer.save_state() + else: + if self._trainer is None: + raise ValueError('Trainer not setup - Run setup_trainer') + else: + raise ValueError('Train data not setup - Run set_train') + return metrics + + def evaluate(self): + if self._eval_dataset is not None and self._trainer is not None: + # Evaluation + self._logger.info("*** Evaluate ***") + metrics = self._trainer.evaluate() + metrics["eval_samples"] = len(self._eval_dataset) + self._trainer.log_metrics("eval", metrics) + self._trainer.save_metrics("eval", metrics) + else: + if self._trainer is None: + raise ValueError('Trainer not setup - Run setup_trainer') + else: + raise ValueError('Evaluation data not setup - Run set_eval') + return metrics + + def predict(self, output_predictions_file: Optional[str] = None): + if self._test_dataset is not None and self._trainer is not None: + self._logger.info("*** Predict ***") + predictions, labels, metrics = self._trainer.predict(self._test_dataset, metric_key_prefix="predict") + unique_note_ids = set() + for note_sent_info in self._test_dataset['note_sent_info']: + note_id = note_sent_info['note_id'] + unique_note_ids = unique_note_ids | {note_id} + note_ids = list(unique_note_ids) + note_level_aggregator = NoteLevelAggregator( + note_ids=note_ids, + 
note_sent_info=self._test_dataset['note_sent_info'] + ) + note_tokens = note_level_aggregator.get_aggregate_sequences( + sequences=self._test_dataset['current_sent_info'] + ) + true_predictions, true_labels = self._post_processor.decode(predictions, labels) + note_predictions = note_level_aggregator.get_aggregate_sequences(sequences=true_predictions) + note_labels = note_level_aggregator.get_aggregate_sequences(sequences=true_labels) + self._trainer.log_metrics("test", metrics) + self._trainer.save_metrics("test", metrics) + if output_predictions_file is not None: + # Save predictions + with open(output_predictions_file, "w") as file: + for note_id, note_token, note_label, note_prediction in zip( + note_ids, + note_tokens, + note_labels, + note_predictions + ): + prediction_info = { + 'note_id': note_id, + 'tokens': note_token, + 'labels': note_label, + 'predictions': note_prediction + } + file.write(json.dumps(prediction_info) + '\n') + + else: + if self._trainer is None: + raise ValueError('Trainer not setup - Run setup_trainer') + else: + raise ValueError('Test data not setup - Run set_predict') + + def __get_config(self): + return AutoConfig.from_pretrained( + self._config_name, + num_labels=len(self._label_to_id.keys()), + label2id=self._label_to_id, + id2label=self._id_to_label, + finetuning_task=self._task_name, + cache_dir=self._cache_dir, + revision=self._model_revision, + use_auth_token=self._use_auth_token, + ) + + def __get_tokenizer(self, do_lower_case=False): + if self._config is None: + raise ValueError('Model config not initialized') + if self._config.model_type in {"gpt2", "roberta"}: + tokenizer = AutoTokenizer.from_pretrained( + self._tokenizer_name, + cache_dir=self._cache_dir, + use_fast=True, + do_lower_case=do_lower_case, + revision=self._model_revision, + use_auth_token=self._use_auth_token, + add_prefix_space=True, + ) + else: + tokenizer = AutoTokenizer.from_pretrained( + self._tokenizer_name, + cache_dir=self._cache_dir, + use_fast=True, + do_lower_case=do_lower_case, + revision=self._model_revision, + use_auth_token=self._use_auth_token, + ) + # Tokenizer check: this script requires a fast tokenizer. + if not isinstance(tokenizer, PreTrainedTokenizerFast): + raise ValueError( + "This example script only works for models that have a fast tokenizer. 
" + "Checkout the big table of models ""at https://huggingface.co/transformers/index.html " + "#bigtable to find the model types that meet this requirement") + return tokenizer + + def __get_model(self, threshold: Optional[float] = None): + # Get the model + post_process_picker = PostProcessPicker(label_list=self._label_list) + model_picker = ModelPicker( + model_name_or_path=self._model_name_or_path, + config=self._config, + cache_dir=self._cache_dir, + model_revision=self._model_revision, + use_auth_token=self._use_auth_token + ) + if self._post_process == 'argmax': + model = model_picker.get_argmax_bert_model() + post_processor = post_process_picker.get_argmax() + elif self._post_process == 'threshold': + model = model_picker.get_argmax_bert_model() + post_processor = post_process_picker.get_threshold_max(threshold=threshold) + elif self._post_process == 'logits': + model = model_picker.get_argmax_bert_model() + post_processor = post_process_picker.get_logits() + elif self._post_process == 'crf': + model = model_picker.get_crf_bert_model(notation=self._notation, id_to_label=self._id_to_label) + post_processor = post_process_picker.get_crf() + post_processor.set_crf(model.crf) + else: + raise ValueError('Invalid post_process argument') + return model, post_processor + + +def main(): + parser = HfArgumentParser(( + ModelArguments, + DataTrainingArguments, + EvaluationArguments, + TrainingArguments + )) + if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): + # If we pass only one argument to the script and it's the path to a json file, + # let's parse it to get our arguments. + model_args, data_args, evaluation_args, training_args = \ + parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) + else: + model_args, data_args, evaluation_args, training_args = \ + parser.parse_args_into_dataclasses() + + sequence_tagger = SequenceTagger( + task_name=data_args.task_name, + notation=data_args.notation, + ner_types=data_args.ner_types, + model_name_or_path=model_args.model_name_or_path, + config_name=model_args.config_name, + tokenizer_name=model_args.tokenizer_name, + post_process=model_args.post_process, + cache_dir=model_args.cache_dir, + model_revision=model_args.model_revision, + use_auth_token=model_args.use_auth_token, + threshold=model_args.threshold, + do_lower_case=data_args.do_lower_case, + fp16=training_args.fp16, + seed=training_args.seed, + local_rank=training_args.local_rank + ) + sequence_tagger.load() + if training_args.do_train: + sequence_tagger.set_train( + train_file=data_args.train_file, + max_train_samples=data_args.max_train_samples, + preprocessing_num_workers=data_args.preprocessing_num_workers, + overwrite_cache=data_args.overwrite_cache + ) + if training_args.do_eval: + sequence_tagger.set_eval( + validation_file=data_args.validation_file, + max_val_samples=data_args.max_eval_samples, + preprocessing_num_workers=data_args.preprocessing_num_workers, + overwrite_cache=data_args.overwrite_cache + ) + sequence_tagger.set_eval_metrics( + validation_spans_file=evaluation_args.validation_spans_file, + model_eval_script=evaluation_args.model_eval_script, + ner_types_maps=evaluation_args.ner_type_maps, + evaluation_mode=evaluation_args.evaluation_mode + ) + if training_args.do_predict: + sequence_tagger.set_predict( + test_file=data_args.test_file, + max_test_samples=data_args.max_predict_samples, + preprocessing_num_workers=data_args.preprocessing_num_workers, + overwrite_cache=data_args.overwrite_cache + ) + 
+    checkpoint = sequence_tagger.setup_trainer(training_args=training_args)
+    if training_args.do_train:
+        # Use the checkpoint returned by setup_trainer, which falls back to the
+        # last checkpoint detected in the output directory
+        sequence_tagger.train(checkpoint=checkpoint)
+    if training_args.do_eval:
+        sequence_tagger.evaluate()
+    if training_args.do_predict:
+        sequence_tagger.predict(output_predictions_file=data_args.output_predictions_file)
+
+
+if __name__ == '__main__':
+    main()