"""NER annotation module using GLiNER models.""" from typing import List, Dict, Union, Optional import torch import random from gliner import GLiNER from ..utils.text_processing import tokenize_text class AutoAnnotator: """A class for automatic NER annotation using GLiNER models.""" def __init__( self, model: str = "BookingCare/gliner-multi-healthcare", device: Optional[torch.device] = None ) -> None: """Initialize the annotator with a GLiNER model. Args: model: Name or path of the GLiNER model to use device: Device to run the model on (CPU/GPU) """ if device is None: device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') # Set PyTorch memory management settings if torch.cuda.is_available(): torch.cuda.empty_cache() torch.cuda.set_per_process_memory_fraction(0.8) # Use 80% of available GPU memory self.model = GLiNER.from_pretrained(model).to(device) self.annotated_data = [] self.stat = { "total": None, "current": -1 } def auto_annotate( self, data: List[str], labels: List[str], prompt: Optional[Union[str, List[str]]] = None, threshold: float = 0.5, nested_ner: bool = False ) -> List[Dict]: """Annotate a list of texts with NER labels. Args: data: List of texts to annotate labels: List of entity labels to detect prompt: Optional prompt or list of prompts to use threshold: Confidence threshold for entity detection nested_ner: Whether to allow nested entities Returns: List of annotated examples """ self.stat["total"] = len(data) self.stat["current"] = -1 # Process texts in batches processed_data = [] batch_size = 8 # Reduced batch size to prevent OOM errors for i in range(0, len(data), batch_size): batch_texts = data[i:i + batch_size] batch_with_prompts = [] # Add prompts to batch texts for text in batch_texts: if isinstance(prompt, list): prompt_text = random.choice(prompt) else: prompt_text = prompt text_with_prompt = f"{prompt_text}\n{text}" if prompt_text else text batch_with_prompts.append(text_with_prompt) # Process batch batch_results = self._batch_annotate_text( batch_with_prompts, labels, threshold, nested_ner ) processed_data.extend(batch_results) # Clear CUDA cache after each batch if torch.cuda.is_available(): torch.cuda.empty_cache() # Update progress self.stat["current"] = min(i + batch_size, len(data)) self.annotated_data = processed_data return self.annotated_data def _batch_annotate_text( self, texts: List[str], labels: List[str], threshold: float, nested_ner: bool ) -> List[Dict]: """Annotate multiple texts in batch. Args: texts: List of texts to annotate labels: List of entity labels threshold: Confidence threshold nested_ner: Whether to allow nested entities Returns: List of annotated examples """ batch_entities = self.model.batch_predict_entities( texts, labels, flat_ner=not nested_ner, threshold=threshold ) results = [] for text, entities in zip(texts, batch_entities): r = { "text": text, "entities": [ { "entity": entity["label"], "word": entity["text"], "start": entity["start"], "end": entity["end"], "score": 0, } for entity in entities ], } r["entities"] = self._merge_entities(r["entities"]) results.append(self._transform_data(r)) return results def _merge_entities(self, entities: List[Dict]) -> List[Dict]: """Merge adjacent entities of the same type. 
Args: entities: List of entity dictionaries Returns: List of merged entities """ if not entities: return [] merged = [] current = entities[0] for next_entity in entities[1:]: if (next_entity['entity'] == current['entity'] and (next_entity['start'] == current['end'] + 1 or next_entity['start'] == current['end'])): current['word'] += ' ' + next_entity['word'] current['end'] = next_entity['end'] else: merged.append(current) current = next_entity merged.append(current) return merged def _transform_data(self, data: Dict) -> Dict: """Transform raw annotation data into tokenized format. Args: data: Raw annotation data Returns: Transformed data with tokenized text and NER spans """ tokens = tokenize_text(data['text']) spans = [] for entity in data['entities']: entity_tokens = tokenize_text(entity['word']) entity_length = len(entity_tokens) # Find the start and end indices of each entity in the tokenized text for i in range(len(tokens) - entity_length + 1): if tokens[i:i + entity_length] == entity_tokens: spans.append([i, i + entity_length - 1, entity['entity']]) break return { "tokenized_text": tokens, "ner": spans, "validated": False }
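
# Minimal usage sketch (illustrative only): the model name is the class
# default, but the texts and label set below are made-up examples, and
# loading the model requires the `gliner` package plus network access.
# Because this module uses a relative import, run it as part of its package
# (e.g. `python -m <package>.<module>`) rather than as a standalone script.
if __name__ == "__main__":
    annotator = AutoAnnotator()
    examples = annotator.auto_annotate(
        data=["The patient was prescribed ibuprofen for knee pain."],
        labels=["medication", "symptom"],
        threshold=0.5,
    )
    # Each example holds a token list and inclusive [start, end, label] spans
    for example in examples:
        print(example["tokenized_text"])
        print(example["ner"])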