# update from https://github.com/ArneBinder/argumentation-structure-identification/pull/529 (commit d868d2e)
import abc
import logging
from typing import Callable, List, Optional

import torch
import torch.nn.functional as F
from adapters import AutoAdapterModel
from pie_modules.models import SequencePairSimilarityModelWithPooler
from pie_modules.models.sequence_classification_with_pooler import (
    InputType,
    OutputType,
    SequenceClassificationModelWithPooler,
    SequenceClassificationModelWithPoolerBase,
    TargetType,
    separate_arguments_by_prefix,
)
from pytorch_ie import PyTorchIEModel
from torch import FloatTensor, Tensor
from transformers import AutoConfig, PreTrainedModel
from transformers.modeling_outputs import SequenceClassifierOutput

logger = logging.getLogger(__name__)

class SequenceClassificationModelWithPoolerAndAdapterBase(
    SequenceClassificationModelWithPoolerBase, abc.ABC
):
    def __init__(self, adapter_name_or_path: Optional[str] = None, **kwargs):
        self.adapter_name_or_path = adapter_name_or_path
        super().__init__(**kwargs)

    def setup_base_model(self) -> PreTrainedModel:
        if self.adapter_name_or_path is None:
            return super().setup_base_model()
        else:
            config = AutoConfig.from_pretrained(self.model_name_or_path)
            if self.is_from_pretrained:
                # the transformer weights will be overwritten when the serialized model
                # state is restored, so constructing from the config suffices
                model = AutoAdapterModel.from_config(config=config)
            else:
                model = AutoAdapterModel.from_pretrained(self.model_name_or_path, config=config)
            # load the adapter in any case: the adapter weights do not appear to be
            # saved in, or restored from, the serialized model state
            logger.info(f"load adapter: {self.adapter_name_or_path}")
            model.load_adapter(self.adapter_name_or_path, source="hf", set_active=True)
            return model

class SequencePairSimilarityModelWithPoolerAndAdapter(
    SequencePairSimilarityModelWithPooler, SequenceClassificationModelWithPoolerAndAdapterBase
):
    pass


class SequenceClassificationModelWithPoolerAndAdapter(
    SequenceClassificationModelWithPooler, SequenceClassificationModelWithPoolerAndAdapterBase
):
    pass

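# Note on the combination classes above: they rely on Python's method resolution
# order. The concrete pie_modules model (listed first) provides the task logic,
# while SequenceClassificationModelWithPoolerAndAdapterBase overrides
# setup_base_model to swap in an adapter-enabled backbone.
#
# Hedged usage sketch (the model and adapter identifiers are placeholders; any
# further constructor arguments come from the pie_modules base classes):
#
#   model = SequencePairSimilarityModelWithPoolerAndAdapter(
#       model_name_or_path="bert-base-uncased",
#       adapter_name_or_path="some-org/some-adapter",
#   )
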
def get_max_cosine_sim(embeddings: Tensor, embeddings_pair: Tensor) -> Tensor:
    # normalize the embeddings
    embeddings_normalized = F.normalize(embeddings, p=2, dim=1)  # shape: (n, k)
    embeddings_normalized_pair = F.normalize(embeddings_pair, p=2, dim=1)  # shape: (m, k)
    # compute the pairwise cosine similarity matrix
    cosine_sim = torch.mm(embeddings_normalized, embeddings_normalized_pair.T)  # shape: (n, m)
    # take the overall maximum cosine similarity value (a scalar tensor)
    max_cosine_sim = torch.max(cosine_sim)
    return max_cosine_sim

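# Minimal worked example (made-up tensors, for illustration only):
#
#   a = torch.tensor([[1.0, 0.0], [0.0, 1.0]])  # two embeddings of dim 2
#   b = torch.tensor([[0.0, 2.0]])              # one embedding of dim 2
#   get_max_cosine_sim(a, b)  # -> tensor(1.), from the pair ([0., 1.], [0., 2.])
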
def get_span_embeddings(
    embeddings: FloatTensor, start_indices: Tensor, end_indices: Tensor
) -> List[FloatTensor]:
    result = []
    for embeds, starts, ends in zip(embeddings, start_indices, end_indices):
        # only the first start/end index per example is used, i.e. exactly one
        # span is extracted per sequence
        span_embeds = embeds[starts[0] : ends[0]]
        result.append(span_embeds)
    return result

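# Shape sketch (assumed shapes, for illustration): for embeddings of shape
# (batch=2, seq_len=5, dim=8), start_indices=[[1], [2]] and end_indices=[[3], [4]],
# the result is a list of two tensors of shape (2, 8): the embeddings of tokens
# 1..2 of the first sequence and of tokens 2..3 of the second.
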
class SequencePairSimilarityModelWithMaxCosineSim(SequencePairSimilarityModelWithPooler):
    def get_pooled_output(self, model_inputs, pooler_inputs) -> List[FloatTensor]:
        output = self.model(**model_inputs)
        hidden_state = output.last_hidden_state
        # instead of pooling (and applying dropout) as done in the base class,
        # return the raw token embeddings of the annotated spans:
        # pooled_output = self.pooler(hidden_state, **pooler_inputs)
        # pooled_output = self.dropout(pooled_output)
        span_embeds = get_span_embeddings(hidden_state, **pooler_inputs)
        return span_embeds
    def forward(
        self,
        inputs: InputType,
        targets: Optional[TargetType] = None,
        return_hidden_states: bool = False,
    ) -> OutputType:
        sanitized_inputs = separate_arguments_by_prefix(
            # Note that the order of the prefixes is important because one is a
            # prefix of the other, so we need to start with the longer one!
            arguments=inputs,
            prefixes=["pooler_pair_", "pooler_"],
        )

        span_embeddings = self.get_pooled_output(
            model_inputs=sanitized_inputs["remaining"]["encoding"],
            pooler_inputs=sanitized_inputs["pooler_"],
        )
        span_embeddings_pair = self.get_pooled_output(
            model_inputs=sanitized_inputs["remaining"]["encoding_pair"],
            pooler_inputs=sanitized_inputs["pooler_pair_"],
        )

        logits_list = [
            get_max_cosine_sim(span_embeds, span_embeds_pair)
            for span_embeds, span_embeds_pair in zip(span_embeddings, span_embeddings_pair)
        ]
        logits = torch.stack(logits_list)

        result = {"logits": logits}
        if targets is not None:
            labels = targets["scores"]
            loss = self.loss_fct(logits, labels)
            result["loss"] = loss
        if return_hidden_states:
            raise NotImplementedError("return_hidden_states is not yet implemented")

        return SequenceClassifierOutput(**result)

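# Hedged sketch of the expected `inputs` structure, inferred from the prefix
# handling in forward() and the keyword arguments of get_span_embeddings()
# (the exact key set is an assumption):
#
#   inputs = {
#       "encoding": {"input_ids": ..., "attention_mask": ...},
#       "encoding_pair": {"input_ids": ..., "attention_mask": ...},
#       "pooler_start_indices": ...,       # -> start_indices for the first text
#       "pooler_end_indices": ...,         # -> end_indices for the first text
#       "pooler_pair_start_indices": ...,  # -> start_indices for the paired text
#       "pooler_pair_end_indices": ...,    # -> end_indices for the paired text
#   }
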
class SequencePairSimilarityModelWithMaxCosineSimAndAdapter(
    SequencePairSimilarityModelWithMaxCosineSim, SequencePairSimilarityModelWithPoolerAndAdapter
):
    pass

class SequencePairSimilarityModelDummy(SequencePairSimilarityModelWithPooler):
    def __init__(
        self,
        method: str = "random",
        random_seed: Optional[int] = None,
        **kwargs,
    ):
        self.method = method
        self.random_seed = random_seed
        super().__init__(**kwargs)

    def setup_classifier(
        self, pooler_output_dim: int
    ) -> Callable[[torch.FloatTensor, torch.FloatTensor], torch.FloatTensor]:
        if self.method == "random":
            generator = torch.Generator(device=self.device)
            if self.random_seed is not None:
                generator = generator.manual_seed(self.random_seed)

            def binary_classify_random(
                inputs: torch.FloatTensor,
                inputs_pair: torch.FloatTensor,
            ) -> torch.FloatTensor:
                """Randomly classifies pairs of inputs as similar or not similar."""
                # generate random logits in the range [0, 1)
                logits = torch.rand(inputs.size(0), device=self.device, generator=generator)
                return logits

            return binary_classify_random
        elif self.method == "zero":

            def binary_classify_zero(
                inputs: torch.FloatTensor,
                inputs_pair: torch.FloatTensor,
            ) -> torch.FloatTensor:
                """Classifies pairs of inputs as not similar (logit = 0)."""
                # return a tensor of zeros with the same batch size
                logits = torch.zeros(inputs.size(0), device=self.device)
                return logits

            return binary_classify_zero
        else:
            raise ValueError(
                f"Unknown method: {self.method}. Supported methods are 'random' and 'zero'."
            )

    def setup_loss_fct(self) -> Callable:
        def loss_fct(logits: FloatTensor, labels: FloatTensor) -> FloatTensor:
            raise NotImplementedError(
                "The dummy model does not support a loss function, as it is not meant to be trained."
            )

        return loss_fct

    def get_pooled_output(self, model_inputs, pooler_inputs) -> torch.FloatTensor:
        # just return a zero tensor with one entry per batch element so that the
        # classifier can construct dummy logits of the correct shape
        bs = pooler_inputs["start_indices"].size(0)
        return torch.zeros(bs, device=self.device)
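
# Hedged usage sketch (the model name is a placeholder; remaining constructor
# arguments are forwarded to the pie_modules base class):
#
#   dummy = SequencePairSimilarityModelDummy(
#       model_name_or_path="bert-base-uncased",
#       method="random",
#       random_seed=42,
#   )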