# 0-shot-NER / ner.py
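# Zero-shot named-entity recognition: the input text is split into sentence
# batches, each batch is prefixed with a label-specific prompt, and a Hugging
# Face token-classification pipeline extracts spans, whose offsets are then
# mapped back to the original text.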
from typing import Any, Tuple
import string
from transformers import pipeline, AutoTokenizer, AutoModelForTokenClassification
import spacy
import torch
import gradio as gr


class NER:
    """Prompt-based zero-shot NER over a token-classification pipeline."""

    prompt: str = """
Identify entities in the text having the following classes:
{}
Text:
"""

    def __init__(
        self,
        model_name: str,
        sents_batch: int = 10,
        tokens_limit: int = 2048
    ):
        self.sents_batch = sents_batch
        self.tokens_limit = tokens_limit

        # spaCy is used only for sentence splitting; other components are disabled.
        self.nlp: spacy.Language = spacy.load(
            'en_core_web_sm',
            disable=['lemmatizer', 'parser', 'tagger', 'ner']
        )
        self.nlp.add_pipe('sentencizer')

        device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForTokenClassification.from_pretrained(model_name)
        self.pipeline = pipeline(
            "ner",
            model=model,
            tokenizer=self.tokenizer,
            aggregation_strategy='first',
            batch_size=12,
            device=device
        )

    def get_last_sentence_id(self, i: int, sentences_len: int) -> int:
        return min(i + self.sents_batch, sentences_len) - 1

    def chunkanize(self, text: str) -> Tuple[list[str], list[int]]:
        # Split the text into chunks of `sents_batch` sentences and record
        # each chunk's character offset into the original text.
        doc = self.nlp(text)

        chunks = []
        starts = []

        sentences = list(doc.sents)

        for i in range(0, len(sentences), self.sents_batch):
            start = sentences[i].start_char
            starts.append(start)

            last_sentence = self.get_last_sentence_id(i, len(sentences))
            end = sentences[last_sentence].end_char

            chunks.append(text[start:end])

        return chunks, starts

    def get_inputs(
        self, chunks: list[str], labels: list[str]
    ) -> Tuple[list[str], list[int]]:
        # Build one pipeline input per (label, chunk) pair: the prompt with the
        # label filled in, followed by the chunk text. Inputs are grouped label
        # by label, which predict() relies on when mapping outputs back.
        inputs = []
        prompts_lens = []

        for label in labels:
            prompt = self.prompt.format(label)
            prompts_lens.append(len(prompt))
            for chunk in chunks:
                inputs.append(prompt + chunk)

        return inputs, prompts_lens

    @classmethod
    def clean_span(
        cls, start: int, end: int, span: str
    ) -> Tuple[int, int, str]:
        # Recursively trim leading/trailing punctuation, adjusting the offsets.
        if len(span) >= 1:
            if span[0] in string.punctuation:
                return cls.clean_span(start + 1, end, span[1:])
            if span[-1] in string.punctuation:
                return cls.clean_span(start, end - 1, span[:-1])
        return start, end, span.strip()

    def predict(
        self,
        text: str,
        inputs: list[str],
        labels: list[str],
        chunks_starts: list[int],
        prompts_lens: list[int],
        threshold: float
    ) -> list[dict[str, Any]]:
        outputs = []

        for idx, output in enumerate(self.pipeline(inputs)):
            # Inputs are ordered label by label, chunk by chunk (see get_inputs).
            label_id = idx // len(chunks_starts)
            chunk_id = idx % len(chunks_starts)

            label = labels[label_id]
            # Shift entity offsets from prompt+chunk space back to the original text.
            shift = chunks_starts[chunk_id] - prompts_lens[label_id]

            for ent in output:
                start = ent['start'] + shift + 1
                end = ent['end'] + shift

                start, end, span = self.clean_span(start, end, text[start:end])

                if not span:
                    continue

                if ent['score'] >= threshold:
                    outputs.append({
                        'span': span,
                        'start': start,
                        'end': end,
                        'entity': label
                    })

        return outputs

    def check_text(self, text: str) -> None:
        if not text:
            raise gr.Error('No text provided. Please provide text.')

    def check_labels(self, labels: list[str]) -> None:
        if not labels:
            raise gr.Error(
                'No labels provided. Please provide labels.'
                ' Multiple labels should be divided by commas.'
                ' See examples below.'
            )

    def check_tokens_limit(self, inputs: list[str]) -> None:
        tokens = 0
        for input_ in inputs:
            tokens += len(self.tokenizer.encode(input_))
        if tokens > self.tokens_limit:
            raise gr.Error(
                'Too many tokens! Please reduce size of text or amount of labels.'
                f' Max tokens count is: {self.tokens_limit}.'
            )

    def process(
        self, labels: str, text: str, threshold: float = 0.
    ) -> dict[str, Any]:
        # Parse the comma-separated labels string into a deduplicated list.
        labels_list = list({
            l for label in labels.split(',')
            if (l := label.strip())
        })

        self.check_labels(labels_list)
        self.check_text(text)

        chunks, chunks_starts = self.chunkanize(text)
        inputs, prompts_lens = self.get_inputs(chunks, labels_list)

        self.check_tokens_limit(inputs)

        outputs = self.predict(
            text, inputs, labels_list, chunks_starts, prompts_lens, threshold
        )

        return {"text": text, "entities": outputs}