""" | |
Source: https://github.com/ZurichNLP/recognizing-semantic-differences | |
MIT License | |
Copyright (c) 2023 University of Zurich | |
""" | |
import itertools | |
from typing import List, Union | |
import torch | |
import transformers | |
from transformers import FeatureExtractionPipeline, Pipeline | |
from recognizers.base import DifferenceRecognizer | |
from recognizers.utils import DifferenceSample | |
Ngram = List[int] # A span of subword indices | |


class FeatureExtractionRecognizer(DifferenceRecognizer):

    def __init__(self,
                 model_name_or_path: str = None,
                 pipeline: Union[FeatureExtractionPipeline, Pipeline] = None,
                 layer: int = -1,
                 batch_size: int = 16,
                 ):
        # Either a model name or a ready-made feature-extraction pipeline must be provided.
        assert model_name_or_path is not None or pipeline is not None
        if pipeline is None:
            pipeline = transformers.pipeline(
                model=model_name_or_path,
                task="feature-extraction",
            )
        self.pipeline = pipeline
        self.layer = layer
        self.batch_size = batch_size

    def encode_batch(self, sentences: List[str], **kwargs) -> torch.Tensor:
        # Tokenize the sentences and return the hidden states of the configured layer
        # (batch x seq_len x hidden_dim).
        model_inputs = self.pipeline.tokenizer(sentences, return_tensors="pt", padding=True, truncation=True)
        model_inputs = model_inputs.to(self.pipeline.device)
        outputs = self.pipeline.model(**model_inputs, output_hidden_states=True, **kwargs)
        return outputs.hidden_states[self.layer]

    def predict(self,
                a: str,
                b: str,
                **kwargs,
                ) -> DifferenceSample:
        return self.predict_all([a], [b], **kwargs)[0]

    def predict_all(self,
                    a: List[str],
                    b: List[str],
                    **kwargs,
                    ) -> List[DifferenceSample]:
        # Process the sentence pairs in chunks of `batch_size`.
        samples = []
        for i in range(0, len(a), self.batch_size):
            samples.extend(self._predict_all(
                a[i:i + self.batch_size],
                b[i:i + self.batch_size],
                **kwargs,
            ))
        return samples

    def _predict_all(self,
                     a: List[str],
                     b: List[str],
                     **kwargs,
                     ) -> List[DifferenceSample]:
        # To be implemented by subclasses.
        raise NotImplementedError

    def _pool(self, token_embeddings: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """
        :param token_embeddings: batch x seq_len x dim
        :param mask: batch x seq_len; 1 if the token should be included in the pooling
        :return: batch x dim
        Only sum the embeddings; do not divide by the number of tokens, since cosine similarity
        is invariant to the magnitude of the pooled vector.
        """
        return torch.sum(token_embeddings * mask.unsqueeze(-1), dim=1)

    def _get_subwords_by_word(self, sentence: str) -> List[Ngram]:
        """
        :return: For each word in the sentence, the positions of the subwords that make up the word.
        """
        batch_encoding = self.pipeline.tokenizer(
            sentence,
            padding=True,
            truncation=True,
        )
        subword_ids: List[List[int]] = []
        for subword_idx in range(len(batch_encoding.encodings[0].word_ids)):
            if batch_encoding.encodings[0].word_ids[subword_idx] is None:  # Special token
                continue
            char_idx = batch_encoding.encodings[0].offsets[subword_idx][0]
            if isinstance(self.pipeline.tokenizer, transformers.XLMRobertaTokenizerFast) or \
                    isinstance(self.pipeline.tokenizer, transformers.XLMRobertaTokenizer):
                # SentencePiece: "▁" marks a leading space, so a subword without it continues the previous word.
                token = batch_encoding.encodings[0].tokens[subword_idx]
                is_tail = not token.startswith("▁") and token not in self.pipeline.tokenizer.all_special_tokens
            elif isinstance(self.pipeline.tokenizer, transformers.RobertaTokenizerFast) or \
                    isinstance(self.pipeline.tokenizer, transformers.RobertaTokenizer):
                # Byte-level BPE: "Ġ" encodes a leading space, so a subword without it usually continues the previous word.
                token = batch_encoding.encodings[0].tokens[subword_idx]
                is_tail = not token.startswith("Ġ") and token not in self.pipeline.tokenizer.all_special_tokens
            else:
                # Fallback: a subword continues the previous word if its character offset
                # starts exactly where the previous subword ended.
                is_tail = char_idx > 0 and char_idx == batch_encoding.encodings[0].offsets[subword_idx - 1][1]
            if is_tail and len(subword_ids) > 0:
                subword_ids[-1].append(subword_idx)
            else:
                subword_ids.append([subword_idx])
        return subword_ids

    def _get_ngrams(self, subwords_by_word: List[Ngram]) -> List[Ngram]:
        """
        :return: For each subword ngram in the sentence, the positions of the subwords that make up the ngram.
        """
        subwords = list(itertools.chain.from_iterable(subwords_by_word))
        # Always return at least one ngram (reduce n if necessary)
        min_n = min(self.min_n, len(subwords))
        ngrams = []
        for n in range(min_n, self.max_n + 1):
            for i in range(len(subwords) - n + 1):
                ngrams.append(subwords[i:i + n])
        return ngrams

    def _subword_labels_to_word_labels(self, subword_labels: torch.Tensor, subwords_by_words: List[Ngram]) -> List[float]:
        """
        :param subword_labels: num_subwords
        :param subwords_by_words: num_words x num_subwords
        :return: num_words
        """
        labels = []
        for subword_indices in subwords_by_words:
            label = subword_labels[subword_indices].mean().item()
            labels.append(label)
        return labels
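

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original module, for illustration only).
    # It exercises the helper methods directly; predict() and predict_all() require
    # a subclass that implements _predict_all(). The model name below is an arbitrary example.
    recognizer = FeatureExtractionRecognizer(model_name_or_path="xlm-roberta-base")
    hidden_states = recognizer.encode_batch(["A first sentence.", "A second, longer sentence."])
    print(hidden_states.shape)  # batch x seq_len x hidden_dim
    print(recognizer._get_subwords_by_word("A first sentence."))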