import argparse
import json
import re
from pathlib import Path
from typing import Dict, Optional, Union

import torch
import torch.nn.functional as F
from torch import nn
from torch.nn.utils.rnn import pad_sequence
from huggingface_hub import PyTorchModelHubMixin, hf_hub_download
from huggingface_hub.utils import HfHubHTTPError

from modules.base import InstructBase
from modules.evaluator import Evaluator, greedy_search
from modules.layers import LstmSeq2SeqEncoder
from modules.span_rep import SpanRepLayer
from modules.token_rep import TokenRepLayer


class GLiNER(InstructBase, PyTorchModelHubMixin):
    def __init__(self, config):
        super().__init__(config)

        self.config = config

        # special tokens used to build the entity-type prompt
        self.entity_token = "<<ENT>>"
        self.sep_token = "<<SEP>>"

        # usually a pretrained bidirectional transformer; returns the first-subtoken representation of each word
        self.token_rep_layer = TokenRepLayer(
            model_name=config.model_name,
            fine_tune=config.fine_tune,
            subtoken_pooling=config.subtoken_pooling,
            hidden_size=config.hidden_size,
            add_tokens=[self.entity_token, self.sep_token],
        )

        # hierarchical representation of tokens
        self.rnn = LstmSeq2SeqEncoder(
            input_size=config.hidden_size,
            hidden_size=config.hidden_size // 2,
            num_layers=1,
            bidirectional=True,
        )

        # span representation
        self.span_rep_layer = SpanRepLayer(
            span_mode=config.span_mode,
            hidden_size=config.hidden_size,
            max_width=config.max_width,
            dropout=config.dropout,
        )

        # prompt (entity type) representation (FFN)
        self.prompt_rep_layer = nn.Sequential(
            nn.Linear(config.hidden_size, config.hidden_size * 4),
            nn.Dropout(config.dropout),
            nn.ReLU(),
            nn.Linear(config.hidden_size * 4, config.hidden_size),
        )
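
    # Rough shape summary of the scoring pipeline (inferred from the code below,
    # added here for orientation rather than taken from external documentation):
    #   tokens + prompt --token_rep_layer--> word_rep_w_prompt  [B, prompt_len + seq_len, D]
    #   word_rep (prompt stripped) --rnn--> word_rep            [B, seq_len, D]
    #   word_rep --span_rep_layer--> span_rep                   [B, seq_len, max_width, D]
    #   [ENT] embeddings --prompt_rep_layer--> entity_type_rep  [B, num_types, D]
    #   einsum('BLKD,BCD->BLKC') --> scores                     [B, seq_len, max_width, num_types]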

    def compute_score_train(self, x):
        span_idx = x['span_idx'] * x['span_mask'].unsqueeze(-1)

        new_length = x['seq_length'].clone()
        new_tokens = []
        all_len_prompt = []
        num_classes_all = []

        # add the entity-type prompt to the tokens of each example
        for i in range(len(x['tokens'])):
            all_types_i = list(x['classes_to_id'][i].keys())
            # multiple entity types in all_types_i; the prompt is prepended to the tokens
            entity_prompt = []
            num_classes_all.append(len(all_types_i))
            # add entity types to the prompt
            for entity_type in all_types_i:
                entity_prompt.append(self.entity_token)  # [ENT] token
                entity_prompt.append(entity_type)  # entity type

            entity_prompt.append(self.sep_token)  # [SEP] token

            # prompt format:
            # [ENT] entity_type_1 [ENT] entity_type_2 ... [ENT] entity_type_m [SEP]
            # e.g. ["<<ENT>>", "person", "<<ENT>>", "location", "<<SEP>>"]

            # add the prompt to the tokens
            tokens_p = entity_prompt + x['tokens'][i]

            # input format:
            # [ENT] entity_type_1 [ENT] entity_type_2 ... [ENT] entity_type_m [SEP] token_1 token_2 ... token_n

            # update the sequence length (add the prompt length to the original length)
            new_length[i] = new_length[i] + len(entity_prompt)
            # update tokens
            new_tokens.append(tokens_p)
            # store the prompt length
            all_len_prompt.append(len(entity_prompt))

        # mask over entity-type positions: 1 for valid class indices, 0 for padding beyond the number of classes
        max_num_classes = max(num_classes_all)
        entity_type_mask = torch.arange(max_num_classes).unsqueeze(0).expand(len(num_classes_all), -1).to(
            x['span_mask'].device)
        entity_type_mask = entity_type_mask < torch.tensor(num_classes_all).unsqueeze(-1).to(
            x['span_mask'].device)  # [batch_size, max_num_classes]

        # compute all token representations
        bert_output = self.token_rep_layer(new_tokens, new_length)
        word_rep_w_prompt = bert_output["embeddings"]  # embeddings for all tokens (with prompt)
        mask_w_prompt = bert_output["mask"]  # mask for all tokens (with prompt)

        # split into word representation (after [SEP]), mask (after [SEP]) and entity type representation (before [SEP])
        word_rep = []  # word representation (after [SEP])
        mask = []  # mask (after [SEP])
        entity_type_rep = []  # entity type representation (before [SEP])
        for i in range(len(x['tokens'])):
            prompt_entity_length = all_len_prompt[i]  # length of the prompt for this example
            # get the word representation (after [SEP])
            word_rep.append(word_rep_w_prompt[i, prompt_entity_length:prompt_entity_length + x['seq_length'][i]])
            # get the mask (after [SEP])
            mask.append(mask_w_prompt[i, prompt_entity_length:prompt_entity_length + x['seq_length'][i]])

            # get the entity type representation (before [SEP])
            entity_rep = word_rep_w_prompt[i, :prompt_entity_length - 1]  # remove [SEP]
            entity_rep = entity_rep[0::2]  # every second element starting from the first one, i.e. the [ENT] tokens
            entity_type_rep.append(entity_rep)

        # padding for word_rep, mask and entity_type_rep
        word_rep = pad_sequence(word_rep, batch_first=True)  # [batch_size, seq_len, hidden_size]
        mask = pad_sequence(mask, batch_first=True)  # [batch_size, seq_len]
        entity_type_rep = pad_sequence(entity_type_rep, batch_first=True)  # [batch_size, len_types, hidden_size]

        # compute the span representation
        word_rep = self.rnn(word_rep, mask)
        span_rep = self.span_rep_layer(word_rep, span_idx)

        # compute the final entity type representation (FFN)
        entity_type_rep = self.prompt_rep_layer(entity_type_rep)  # (batch_size, len_types, hidden_size)
        num_classes = entity_type_rep.shape[1]  # number of entity types

        # similarity score between every span and every entity type
        scores = torch.einsum('BLKD,BCD->BLKC', span_rep, entity_type_rep)

        return scores, num_classes, entity_type_mask

    def forward(self, x):
        # compute span scores against every entity type
        scores, num_classes, entity_type_mask = self.compute_score_train(x)
        batch_size = scores.shape[0]

        # binary classification loss over (span, entity type) pairs
        logits_label = scores.view(-1, num_classes)
        labels = x["span_label"].view(-1)  # (batch_size * num_spans)
        mask_label = labels != -1  # (batch_size * num_spans)
        labels.masked_fill_(~mask_label, 0)  # set the labels of padding spans to 0

        # one-hot encoding
        labels_one_hot = torch.zeros(labels.size(0), num_classes + 1, dtype=torch.float32).to(scores.device)
        labels_one_hot.scatter_(1, labels.unsqueeze(1), 1)  # set the corresponding index to 1
        labels_one_hot = labels_one_hot[:, 1:]  # remove the first column (the non-entity class)
        # shape of labels_one_hot: (batch_size * num_spans, num_classes)

        # compute the loss (without reduction)
        all_losses = F.binary_cross_entropy_with_logits(logits_label, labels_one_hot,
                                                        reduction='none')
        # mask the loss using entity_type_mask (B, C)
        masked_loss = all_losses.view(batch_size, -1, num_classes) * entity_type_mask.unsqueeze(1)
        all_losses = masked_loss.view(-1, num_classes)

        # expand mask_label to all_losses
        mask_label = mask_label.unsqueeze(-1).expand_as(all_losses)
        # weight positive (span, type) pairs twice as much as negative ones (2 for positive, 1 for negative)
        weight_c = labels_one_hot + 1
        # apply the mask
        all_losses = all_losses * mask_label.float() * weight_c

        return all_losses.sum()

    def compute_score_eval(self, x, device):
        # at evaluation time a single classes_to_id mapping is shared by the whole batch
        assert isinstance(x['classes_to_id'], dict), "classes_to_id must be a dict"

        span_idx = (x['span_idx'] * x['span_mask'].unsqueeze(-1)).to(device)

        all_types = list(x['classes_to_id'].keys())
        # multiple entity types in all_types; the prompt is prepended to the tokens
        entity_prompt = []

        # add entity types to the prompt
        for entity_type in all_types:
            entity_prompt.append(self.entity_token)
            entity_prompt.append(entity_type)

        entity_prompt.append(self.sep_token)

        prompt_entity_length = len(entity_prompt)

        # add the prompt
        tokens_p = [entity_prompt + tokens for tokens in x['tokens']]
        seq_length_p = x['seq_length'] + prompt_entity_length

        out = self.token_rep_layer(tokens_p, seq_length_p)

        word_rep_w_prompt = out["embeddings"]
        mask_w_prompt = out["mask"]

        # remove the prompt
        word_rep = word_rep_w_prompt[:, prompt_entity_length:, :]
        mask = mask_w_prompt[:, prompt_entity_length:]

        # get the entity type representation
        entity_type_rep = word_rep_w_prompt[:, :prompt_entity_length - 1, :]
        # extract the [ENT] tokens (which are at even positions in entity_type_rep)
        entity_type_rep = entity_type_rep[:, 0::2, :]

        entity_type_rep = self.prompt_rep_layer(entity_type_rep)  # (batch_size, len_types, hidden_size)

        word_rep = self.rnn(word_rep, mask)
        span_rep = self.span_rep_layer(word_rep, span_idx)

        local_scores = torch.einsum('BLKD,BCD->BLKC', span_rep, entity_type_rep)

        return local_scores

    @torch.no_grad()
    def predict(self, x, flat_ner=False, threshold=0.5):
        self.eval()
        local_scores = self.compute_score_eval(x, device=next(self.parameters()).device)
        spans = []
        for i, _ in enumerate(x["tokens"]):
            local_i = local_scores[i]
            # indices of (start, width, class) triples whose probability exceeds the threshold
            wh_i = [idx.tolist() for idx in torch.where(torch.sigmoid(local_i) > threshold)]
            span_i = []
            for s, k, c in zip(*wh_i):
                if s + k < len(x["tokens"][i]):
                    span_i.append((s, s + k, x["id_to_classes"][c + 1], local_i[s, k, c]))
            span_i = greedy_search(span_i, flat_ner)
            spans.append(span_i)
        return spans

    def predict_entities(self, text, labels, flat_ner=True, threshold=0.5):
        # simple regex tokenizer that keeps character offsets for each token
        tokens = []
        start_token_idx_to_text_idx = []
        end_token_idx_to_text_idx = []
        for match in re.finditer(r'\w+(?:[-_]\w+)*|\S', text):
            tokens.append(match.group())
            start_token_idx_to_text_idx.append(match.start())
            end_token_idx_to_text_idx.append(match.end())

        input_x = {"tokenized_text": tokens, "ner": None}
        x = self.collate_fn([input_x], labels)
        output = self.predict(x, flat_ner=flat_ner, threshold=threshold)

        entities = []
        for start_token_idx, end_token_idx, ent_type in output[0]:
            start_text_idx = start_token_idx_to_text_idx[start_token_idx]
            end_text_idx = end_token_idx_to_text_idx[end_token_idx]
            entities.append({
                "start": start_text_idx,
                "end": end_text_idx,
                "text": text[start_text_idx:end_text_idx],
                "label": ent_type,
            })
        return entities

    def evaluate(self, test_data, flat_ner=False, threshold=0.5, batch_size=12, entity_types=None):
        self.eval()
        data_loader = self.create_dataloader(test_data, batch_size=batch_size, entity_types=entity_types, shuffle=False)
        device = next(self.parameters()).device
        all_preds = []
        all_trues = []
        for x in data_loader:
            # move batch tensors to the model's device
            for k, v in x.items():
                if isinstance(v, torch.Tensor):
                    x[k] = v.to(device)
            batch_predictions = self.predict(x, flat_ner, threshold)
            all_preds.extend(batch_predictions)
            all_trues.extend(x["entities"])
        evaluator = Evaluator(all_trues, all_preds)
        out, f1 = evaluator.evaluate()
        return out, f1

    @classmethod
    def _from_pretrained(
        cls,
        *,
        model_id: str,
        revision: Optional[str],
        cache_dir: Optional[Union[str, Path]],
        force_download: bool,
        proxies: Optional[Dict],
        resume_download: bool,
        local_files_only: bool,
        token: Union[str, bool, None],
        map_location: str = "cpu",
        strict: bool = False,
        **model_kwargs,
    ):
        # 1. Backwards compatibility: use "gliner_base.pt" / "gliner_multi.pt", which bundle config and weights
        filenames = ["gliner_base.pt", "gliner_multi.pt"]
        for filename in filenames:
            model_file = Path(model_id) / filename
            if not model_file.exists():
                try:
                    model_file = hf_hub_download(
                        repo_id=model_id,
                        filename=filename,
                        revision=revision,
                        cache_dir=cache_dir,
                        force_download=force_download,
                        proxies=proxies,
                        resume_download=resume_download,
                        token=token,
                        local_files_only=local_files_only,
                    )
                except HfHubHTTPError:
                    continue
            dict_load = torch.load(model_file, map_location=torch.device(map_location))
            config = dict_load["config"]
            state_dict = dict_load["model_weights"]
            config.model_name = "microsoft/deberta-v3-base" if filename == "gliner_base.pt" else "microsoft/mdeberta-v3-base"
            model = cls(config)
            model.load_state_dict(state_dict, strict=strict, assign=True)
            # Required to update flair's internals as well:
            model.to(map_location)
            return model

        # 2. Newer format: use "pytorch_model.bin" and "gliner_config.json"
        from train import load_config_as_namespace

        model_file = Path(model_id) / "pytorch_model.bin"
        if not model_file.exists():
            model_file = hf_hub_download(
                repo_id=model_id,
                filename="pytorch_model.bin",
                revision=revision,
                cache_dir=cache_dir,
                force_download=force_download,
                proxies=proxies,
                resume_download=resume_download,
                token=token,
                local_files_only=local_files_only,
            )
        config_file = Path(model_id) / "gliner_config.json"
        if not config_file.exists():
            config_file = hf_hub_download(
                repo_id=model_id,
                filename="gliner_config.json",
                revision=revision,
                cache_dir=cache_dir,
                force_download=force_download,
                proxies=proxies,
                resume_download=resume_download,
                token=token,
                local_files_only=local_files_only,
            )
        config = load_config_as_namespace(config_file)
        model = cls(config)
        state_dict = torch.load(model_file, map_location=torch.device(map_location))
        model.load_state_dict(state_dict, strict=strict, assign=True)
        model.to(map_location)
        return model

    def save_pretrained(
        self,
        save_directory: Union[str, Path],
        *,
        config: Optional[Union[dict, "DataclassInstance"]] = None,
        repo_id: Optional[str] = None,
        push_to_hub: bool = False,
        **push_to_hub_kwargs,
    ) -> Optional[str]:
        """
        Save weights in a local directory.

        Args:
            save_directory (`str` or `Path`):
                Path to the directory in which the model weights and configuration will be saved.
            config (`dict` or `DataclassInstance`, *optional*):
                Model configuration specified as a key/value dictionary or a dataclass instance.
            push_to_hub (`bool`, *optional*, defaults to `False`):
                Whether or not to push your model to the Hugging Face Hub after saving it.
            repo_id (`str`, *optional*):
                ID of your repository on the Hub. Used only if `push_to_hub=True`. Will default to the folder name
                if not provided.
            kwargs:
                Additional keyword arguments passed along to the [`~ModelHubMixin.push_to_hub`] method.
        """
        save_directory = Path(save_directory)
        save_directory.mkdir(parents=True, exist_ok=True)

        # save model weights
        torch.save(self.state_dict(), save_directory / "pytorch_model.bin")

        # save config (if provided)
        if config is None:
            config = self.config
        if config is not None:
            if isinstance(config, argparse.Namespace):
                config = vars(config)
            (save_directory / "gliner_config.json").write_text(json.dumps(config, indent=2))

        # push to the Hub if required
        if push_to_hub:
            kwargs = push_to_hub_kwargs.copy()  # soft-copy to avoid mutating the input
            if config is not None:  # kwarg for `push_to_hub`
                kwargs["config"] = config
            if repo_id is None:
                repo_id = save_directory.name  # defaults to the `save_directory` name
            return self.push_to_hub(repo_id=repo_id, **kwargs)
        return None

    def to(self, device):
        super().to(device)
        # keep flair's global device in sync with the model (flair is used under the hood for token representations)
        import flair
        flair.device = device
        return self
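

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the original module).
# Assumptions: a GLiNER checkpoint is available locally or on the Hub under the
# placeholder REPO_ID below, and the surrounding project (modules/, train.py)
# is importable. `from_pretrained` is provided by PyTorchModelHubMixin and
# dispatches to the `_from_pretrained` override above.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    REPO_ID = "path-or-repo-id/of-a-gliner-checkpoint"  # placeholder, replace with a real checkpoint

    model = GLiNER.from_pretrained(REPO_ID)
    model.eval()

    text = "Marie Curie received the Nobel Prize in Physics in 1903 in Stockholm."
    labels = ["person", "award", "date", "location"]

    # predict_entities returns character-level spans with their predicted entity types
    for entity in model.predict_entities(text, labels, threshold=0.5):
        print(entity["text"], "=>", entity["label"])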
