""" | |
Source: https://github.com/ZurichNLP/recognizing-semantic-differences | |
MIT License | |
Copyright (c) 2023 University of Zurich | |
""" | |
import itertools
from copy import deepcopy
from typing import Union, List

import torch
from transformers import Pipeline, FeatureExtractionPipeline

from recognizers.feature_based import FeatureExtractionRecognizer, Ngram
from recognizers.utils import DifferenceSample, pairwise_cos_sim, cos_sim


class DiffDel(FeatureExtractionRecognizer):
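    """
    Deletion-based difference recognizer: each sentence is encoded once, and the embedding of a
    partial sentence (an n-gram deleted) is approximated by excluding that n-gram's hidden states
    from the mean pool. An n-gram whose deletion makes the sentence more similar to the other
    sentence receives a higher difference label.
    """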

    def __init__(self,
                 model_name_or_path: str = None,
                 pipeline: Union[FeatureExtractionPipeline, Pipeline] = None,
                 layer: int = -1,
                 batch_size: int = 16,
                 min_n: int = 1,
                 max_n: int = 1,  # Inclusive
                 ):
        super().__init__(model_name_or_path, pipeline, layer, batch_size)
        assert min_n <= max_n
        self.min_n = min_n
        self.max_n = max_n

    def __str__(self):
        return f"DiffDel(model={self.pipeline.model.name_or_path}, layer={self.layer}, " \
               f"min_n={self.min_n}, max_n={self.max_n})"

    def _predict_all(self,
                     a: List[str],
                     b: List[str],
                     **kwargs,
                     ) -> List[DifferenceSample]:
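        """
        :param a: batch of sentences (side A); words are assumed to be whitespace-tokenized
        :param b: batch of sentences (side B), aligned with a
        :return: one DifferenceSample per sentence pair, with one difference label per word
                 (higher values mark words whose deletion makes the sentences more similar)
        """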
        outputs_a = self.encode_batch(a, **kwargs)
        outputs_b = self.encode_batch(b, **kwargs)
        subwords_by_words_a = [self._get_subwords_by_word(sentence) for sentence in a]
        subwords_by_words_b = [self._get_subwords_by_word(sentence) for sentence in b]
        ngrams_a = [self._get_ngrams(subwords_by_word) for subwords_by_word in subwords_by_words_a]
        ngrams_b = [self._get_ngrams(subwords_by_word) for subwords_by_word in subwords_by_words_b]
        # Full-sentence embeddings pool over all subwords that belong to a word (positions outside
        # any word, e.g. special tokens, are excluded from the mask)
        sentence_embeddings_a = self._get_full_sentence_embeddings(
            outputs_a, [list(itertools.chain.from_iterable(subwords)) for subwords in subwords_by_words_a])
        sentence_embeddings_b = self._get_full_sentence_embeddings(
            outputs_b, [list(itertools.chain.from_iterable(subwords)) for subwords in subwords_by_words_b])
        full_similarities = pairwise_cos_sim(sentence_embeddings_a, sentence_embeddings_b)
        all_labels_a = []
        all_labels_b = []
        for i in range(len(a)):
            # Embed each partial sentence (the sentence with one n-gram excluded from the pool)
            partial_embeddings_a = self._get_partial_sentence_embeddings_for_sample(outputs_a[i], ngrams_a[i])
            partial_embeddings_b = self._get_partial_sentence_embeddings_for_sample(outputs_b[i], ngrams_b[i])
            # Compare each partial sentence to the full embedding of the other sentence
            partial_similarities_a = cos_sim(partial_embeddings_a, sentence_embeddings_b[i].unsqueeze(0)).squeeze(1)
            partial_similarities_b = cos_sim(partial_embeddings_b, sentence_embeddings_a[i].unsqueeze(0)).squeeze(1)
            # An n-gram whose deletion increases the cross-sentence similarity gets a higher label
            ngram_labels_a = (partial_similarities_a - full_similarities[i] + 1) / 2
            ngram_labels_b = (partial_similarities_b - full_similarities[i] + 1) / 2
            # Project n-gram labels down to subwords, then aggregate subwords into word-level labels
            subword_labels_a = self._distribute_ngram_labels_to_subwords(ngram_labels_a, ngrams_a[i])
            subword_labels_b = self._distribute_ngram_labels_to_subwords(ngram_labels_b, ngrams_b[i])
            labels_a = self._subword_labels_to_word_labels(subword_labels_a, subwords_by_words_a[i])
            labels_b = self._subword_labels_to_word_labels(subword_labels_b, subwords_by_words_b[i])
            all_labels_a.append(labels_a)
            all_labels_b.append(labels_b)
        samples = []
        for i in range(len(a)):
            samples.append(DifferenceSample(
                tokens_a=tuple(a[i].split()),
                tokens_b=tuple(b[i].split()),
                labels_a=tuple(all_labels_a[i]),
                labels_b=tuple(all_labels_b[i]),
            ))
        return samples

    def _get_full_sentence_embeddings(self, token_embeddings: torch.Tensor, include_subwords: List[List[int]]) -> torch.Tensor:
        """
        :param token_embeddings: batch x seq_len x dim
        :param include_subwords: batch x num_subwords
        :return: A tensor of shape batch x dim
        """
        pool_mask = torch.zeros(token_embeddings.shape[0], token_embeddings.shape[1], device=token_embeddings.device)
        for i, subword_indices in enumerate(include_subwords):
            pool_mask[i, subword_indices] = 1
        sentence_embeddings = self._pool(token_embeddings, pool_mask)
        return sentence_embeddings

    def _get_partial_sentence_embeddings_for_sample(self, token_embeddings: torch.Tensor, ngrams: List[Ngram]) -> torch.Tensor:
        """
        :param token_embeddings: seq_len x dim
        :param ngrams: num_ngrams x n
        :return: A tensor of shape num_ngrams x dim
        """
        # Row i of the mask covers every subword that belongs to some n-gram, except the subwords
        # of n-gram i, so pooling row i yields the sentence embedding with n-gram i deleted.
        pool_mask = torch.zeros(len(ngrams), token_embeddings.shape[0], device=token_embeddings.device)
        pool_mask[:, list(itertools.chain.from_iterable(ngrams))] = 1
        for i, subword_indices in enumerate(ngrams):
            pool_mask[i, subword_indices] = 0
        partial_embeddings = self._pool(token_embeddings.unsqueeze(0).repeat(len(ngrams), 1, 1), pool_mask)
        return partial_embeddings

    def _distribute_ngram_labels_to_subwords(self, ngram_labels: torch.Tensor, ngrams: List[Ngram]) -> torch.Tensor:
        """
        :param ngram_labels: num_ngrams
        :param ngrams: num_ngrams x n
        :return: num_subwords
        """
        max_subword_idx = max(itertools.chain.from_iterable(ngrams))
        subword_contributions = torch.zeros(max_subword_idx + 1, device=ngram_labels.device)
        contribution_count = torch.zeros(max_subword_idx + 1, device=ngram_labels.device)
        for i, ngram in enumerate(ngrams):
            # Each n-gram distributes its label to its subwords, weighted by 1 / len(ngram)
            subword_contributions[ngram] += ngram_labels[i] / len(ngram)
            contribution_count[ngram] += 1 / len(ngram)
        # Normalize: a subword's label is the weighted average of the labels of all n-grams containing it
        subword_contributions /= contribution_count
        return subword_contributions


class DiffDelWithReencode(FeatureExtractionRecognizer):
    """
    Version of DiffDel that encodes the partial sentences from scratch (instead of encoding the
    full sentence once and then excluding hidden states from the mean).
    """

    def __init__(self,
                 model_name_or_path: str = None,
                 pipeline: Union[FeatureExtractionPipeline, Pipeline] = None,
                 layer: int = -1,
                 batch_size: int = 16,
                 ):
        super().__init__(model_name_or_path, pipeline, layer, batch_size)

    def __str__(self):
        return f"DiffDelWithReencode(model={self.pipeline.model.name_or_path}, layer={self.layer})"

    def _predict_all(self,
                     a: List[str],
                     b: List[str],
                     **kwargs,
                     ) -> List[DifferenceSample]:
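        """
        :param a: batch of sentences (side A); words are assumed to be whitespace-tokenized
        :param b: batch of sentences (side B), aligned with a
        :return: one DifferenceSample per sentence pair; each word is labeled by re-encoding the
                 sentence with that word deleted and comparing it to the other sentence
        """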
        a_words = [sentence.split() for sentence in a]
        b_words = [sentence.split() for sentence in b]
        # Build one partial sentence per word by deleting that word
        a_words_partial = []
        b_words_partial = []
        for words in a_words:
            for i, word in enumerate(words):
                partial = deepcopy(words)
                del partial[i]
                a_words_partial.append(partial)
        for words in b_words:
            for i, word in enumerate(words):
                partial = deepcopy(words)
                del partial[i]
                b_words_partial.append(partial)
        a_partial = [" ".join([word for word in words if word]) for words in a_words_partial]
        b_partial = [" ".join([word for word in words if word]) for words in b_words_partial]
        a_num_partial = [len(words) for words in a_words]
        b_num_partial = [len(words) for words in b_words]
        a_embedding_full = self._encode_and_pool(a, **kwargs)
        b_embedding_full = self._encode_and_pool(b, **kwargs)
        # Encode the partial sentences from scratch, in batches
        a_embeddings_partial = []
        b_embeddings_partial = []
        for i in range(0, len(a_partial), self.batch_size):
            a_embeddings_partial_batch = self._encode_and_pool(a_partial[i:i + self.batch_size], **kwargs)
            a_embeddings_partial.append(a_embeddings_partial_batch)
        for i in range(0, len(b_partial), self.batch_size):
            b_embeddings_partial_batch = self._encode_and_pool(b_partial[i:i + self.batch_size], **kwargs)
            b_embeddings_partial.append(b_embeddings_partial_batch)
        a_embeddings_partial = torch.cat(a_embeddings_partial, dim=0)
        b_embeddings_partial = torch.cat(b_embeddings_partial, dim=0)
        labels_a = []
        labels_b = []
        similarity_full = pairwise_cos_sim(a_embedding_full, b_embedding_full)
        for i in range(len(a)):
            # Slice out the partial embeddings that belong to sentence i
            a_embeddings_partial_i = a_embeddings_partial[sum(a_num_partial[:i]):sum(a_num_partial[:i + 1])]
            similarities_partial = pairwise_cos_sim(a_embeddings_partial_i, b_embedding_full[i].unsqueeze(0)).squeeze(0)
            # A word whose deletion increases the cross-sentence similarity gets a higher label
            labels = (similarities_partial - similarity_full[i] + 1) / 2
            labels = labels.detach().cpu().tolist()
            if isinstance(labels, float):  # A single-word sentence yields a 0-dim tensor, so tolist() returns a float
                labels = [labels]
            assert len(labels) == len(a_words[i])
            labels_a.append(labels)
        for i in range(len(b)):
            b_embeddings_partial_i = b_embeddings_partial[sum(b_num_partial[:i]):sum(b_num_partial[:i + 1])]
            similarities_partial = pairwise_cos_sim(b_embeddings_partial_i, a_embedding_full[i].unsqueeze(0)).squeeze(0)
            labels = (similarities_partial - similarity_full[i] + 1) / 2
            labels = labels.detach().cpu().tolist()
            if isinstance(labels, float):
                labels = [labels]
            assert len(labels) == len(b_words[i])
            labels_b.append(labels)
        samples = []
        for i in range(len(a)):
            samples.append(DifferenceSample(
                tokens_a=tuple(a_words[i]),
                tokens_b=tuple(b_words[i]),
                labels_a=tuple(labels_a[i]),
                labels_b=tuple(labels_b[i]),
            ))
        return samples

    def _encode_and_pool(self, sentences: List[str], **kwargs) -> torch.Tensor:
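        """
        :param sentences: batch of sentences to encode
        :return: A tensor of shape batch x dim (attention-masked sum over the selected hidden layer;
                 the scale is irrelevant here because downstream only cosine similarities are computed)
        """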
        model_inputs = self.pipeline.tokenizer(sentences, return_tensors="pt", padding=True, truncation=True)
        model_inputs = model_inputs.to(self.pipeline.device)
        outputs = self.pipeline.model(**model_inputs, output_hidden_states=True, **kwargs)
        if self.layer == "mean":
            # layer may also be the string "mean": average the hidden states of all layers
            token_embeddings = torch.stack(outputs.hidden_states, dim=0).mean(dim=0)
        else:
            assert isinstance(self.layer, int)
            token_embeddings = outputs.hidden_states[self.layer]
        # Pool over non-padding tokens; padding positions are zeroed out via the attention mask
        mask = model_inputs["attention_mask"]
        sentence_embeddings = torch.sum(token_embeddings * mask.unsqueeze(-1), dim=1)
        return sentence_embeddings
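

if __name__ == "__main__":
    # Minimal self-contained sketch (not part of the original module) illustrating the deletion-based
    # scoring rule shared by both recognizers: label = (partial_sim - full_sim + 1) / 2.
    # Random vectors stand in for encoder outputs; the recognizer classes above are not called here,
    # since their public entry point lives in the FeatureExtractionRecognizer base class.
    torch.manual_seed(0)
    dim = 8
    full_a = torch.randn(dim)         # stand-in for the full-sentence embedding of sentence A
    full_b = torch.randn(dim)         # stand-in for the full-sentence embedding of sentence B
    partials_a = torch.randn(3, dim)  # stand-ins for sentence A with each of 3 n-grams deleted
    full_sim = torch.nn.functional.cosine_similarity(full_a, full_b, dim=0)
    partial_sims = torch.nn.functional.cosine_similarity(partials_a, full_b.unsqueeze(0), dim=1)
    # An n-gram whose deletion increases the similarity to sentence B receives a higher label
    labels = (partial_sims - full_sim + 1) / 2
    print(labels)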