"""Sequence classification models with pooler, optional adapter support, and a
max-cosine-similarity variant for sequence pair scoring.

ScientificArgumentRecommender/src/models/sequence_classification_with_pooler.py,
updated from https://github.com/ArneBinder/argumentation-structure-identification/pull/529
"""
import abc
import logging
from typing import Callable, List, Optional
import torch
import torch.nn.functional as F
from adapters import AutoAdapterModel
from pie_modules.models import SequencePairSimilarityModelWithPooler
from pie_modules.models.sequence_classification_with_pooler import (
InputType,
OutputType,
SequenceClassificationModelWithPooler,
SequenceClassificationModelWithPoolerBase,
TargetType,
separate_arguments_by_prefix,
)
from pytorch_ie import PyTorchIEModel
from torch import FloatTensor, Tensor
from transformers import AutoConfig, PreTrainedModel
from transformers.modeling_outputs import SequenceClassifierOutput
logger = logging.getLogger(__name__)
class SequenceClassificationModelWithPoolerAndAdapterBase(
SequenceClassificationModelWithPoolerBase, abc.ABC
):
    def __init__(self, adapter_name_or_path: Optional[str] = None, **kwargs):
        # Name or path of an adapter to load on top of the base model; if None, the
        # plain base model from the superclass is used.
        self.adapter_name_or_path = adapter_name_or_path
        super().__init__(**kwargs)

    def setup_base_model(self) -> PreTrainedModel:
        if self.adapter_name_or_path is None:
            return super().setup_base_model()
        else:
            config = AutoConfig.from_pretrained(self.model_name_or_path)
            if self.is_from_pretrained:
                # The base weights will be restored from the serialized state, so only
                # the architecture is needed here.
                model = AutoAdapterModel.from_config(config=config)
            else:
                model = AutoAdapterModel.from_pretrained(self.model_name_or_path, config=config)
            # Load the adapter in any case: it does not appear to be saved in, or
            # restored from, the serialized state.
            logger.info(f"load adapter: {self.adapter_name_or_path}")
            model.load_adapter(self.adapter_name_or_path, source="hf", set_active=True)
            return model
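
# A minimal configuration sketch (hypothetical model, adapter, and parameter names):
# an adapter model could be instantiated as
#   SequenceClassificationModelWithPoolerAndAdapter(
#       model_name_or_path="bert-base-uncased",
#       adapter_name_or_path="AdapterHub/bert-base-uncased-pf-sst2",
#       num_classes=2,
#   )
# where the adapter name is resolved against the Hugging Face hub via
# `load_adapter(..., source="hf")`.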
@PyTorchIEModel.register()
class SequencePairSimilarityModelWithPoolerAndAdapter(
SequencePairSimilarityModelWithPooler, SequenceClassificationModelWithPoolerAndAdapterBase
):
    """Sequence pair similarity model with pooler and optional adapter support."""
@PyTorchIEModel.register()
class SequenceClassificationModelWithPoolerAndAdapter(
SequenceClassificationModelWithPooler, SequenceClassificationModelWithPoolerAndAdapterBase
):
    """Sequence classification model with pooler and optional adapter support."""
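
# A loading sketch (assuming a PIE-serialized checkpoint at a hypothetical path): the
# `@PyTorchIEModel.register()` decorator makes these classes resolvable by name, so a
# saved model can be restored with something like
#   from pytorch_ie import AutoModel
#   model = AutoModel.from_pretrained("path/to/serialized/checkpoint")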
def get_max_cosine_sim(embeddings: Tensor, embeddings_pair: Tensor) -> Tensor:
    """Return the maximum pairwise cosine similarity between two sets of embeddings."""
    # Normalize the embeddings along the feature dimension
    embeddings_normalized = F.normalize(embeddings, p=2, dim=1)  # shape: (n, k)
    embeddings_normalized_pair = F.normalize(embeddings_pair, p=2, dim=1)  # shape: (m, k)
    # Compute the full cosine similarity matrix
    cosine_sim = torch.mm(embeddings_normalized, embeddings_normalized_pair.T)  # shape: (n, m)
    # Reduce to the overall maximum similarity (a scalar tensor)
    max_cosine_sim = torch.max(cosine_sim)
    return max_cosine_sim
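
# A shape sketch (hypothetical sizes): for two spans of 3 and 5 tokens with hidden
# size 768,
#   get_max_cosine_sim(torch.randn(3, 768), torch.randn(5, 768))
# builds a (3, 5) similarity matrix and returns its single largest entry as a
# zero-dimensional tensor.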
def get_span_embeddings(
    embeddings: FloatTensor, start_indices: Tensor, end_indices: Tensor
) -> List[FloatTensor]:
    """Extract the token embeddings of one span per sequence in the batch."""
    result = []
    for embeds, starts, ends in zip(embeddings, start_indices, end_indices):
        # Note: only the first start/end index of each sequence is used.
        span_embeds = embeds[starts[0] : ends[0]]
        result.append(span_embeds)
    return result
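
# A shape sketch (hypothetical sizes): for a batch of two sequences with hidden size
# 768 and start_indices = [[2], [5]], end_indices = [[4], [10]], the result is a list
# of two tensors of shapes (2, 768) and (5, 768), i.e. the token embeddings of each
# sequence's first span.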
@PyTorchIEModel.register()
class SequencePairSimilarityModelWithMaxCosineSim(SequencePairSimilarityModelWithPooler):
    """Scores a pair of spans by the maximum cosine similarity between their token
    embeddings instead of pooling each span into a single vector."""

    def get_pooled_output(self, model_inputs, pooler_inputs) -> List[FloatTensor]:
        output = self.model(**model_inputs)
        hidden_state = output.last_hidden_state
        # Unlike the base class, no pooler or dropout is applied; the per-token span
        # embeddings are returned directly:
        # pooled_output = self.pooler(hidden_state, **pooler_inputs)
        # pooled_output = self.dropout(pooled_output)
        span_embeds = get_span_embeddings(hidden_state, **pooler_inputs)
        return span_embeds

def forward(
self,
inputs: InputType,
targets: Optional[TargetType] = None,
return_hidden_states: bool = False,
) -> OutputType:
sanitized_inputs = separate_arguments_by_prefix(
            # Note: the order of the prefixes matters because one is a prefix of the
            # other, so the longer one has to come first!
arguments=inputs,
prefixes=["pooler_pair_", "pooler_"],
)
span_embeddings = self.get_pooled_output(
model_inputs=sanitized_inputs["remaining"]["encoding"],
pooler_inputs=sanitized_inputs["pooler_"],
)
span_embeddings_pair = self.get_pooled_output(
model_inputs=sanitized_inputs["remaining"]["encoding_pair"],
pooler_inputs=sanitized_inputs["pooler_pair_"],
)
logits_list = [
get_max_cosine_sim(span_embeds, span_embeds_pair)
for span_embeds, span_embeds_pair in zip(span_embeddings, span_embeddings_pair)
]
logits = torch.stack(logits_list)
result = {"logits": logits}
if targets is not None:
labels = targets["scores"]
loss = self.loss_fct(logits, labels)
result["loss"] = loss
if return_hidden_states:
raise NotImplementedError("return_hidden_states is not yet implemented")
return SequenceClassifierOutput(**result)
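
    # A forward-pass sketch (hypothetical input keys, following the prefix convention
    # above): `inputs` would contain the tokenizer output under `encoding` /
    # `encoding_pair` plus span positions under `pooler_start_indices` /
    # `pooler_end_indices` and their `pooler_pair_` counterparts. For a batch of two
    # pairs, the returned `logits` have shape (2,), each entry being the maximum
    # cosine similarity between the token embeddings of the two spans.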
@PyTorchIEModel.register()
class SequencePairSimilarityModelWithMaxCosineSimAndAdapter(
SequencePairSimilarityModelWithMaxCosineSim, SequencePairSimilarityModelWithPoolerAndAdapter
):
    """Max-cosine-similarity sequence pair model with optional adapter support."""
@PyTorchIEModel.register()
class SequencePairSimilarityModelDummy(SequencePairSimilarityModelWithPooler):
    """Baseline that ignores its inputs and predicts random or all-zero similarity
    logits, e.g. to establish a lower bound in evaluation."""

    def __init__(
        self,
        method: str = "random",
        random_seed: Optional[int] = None,
        **kwargs,
    ):
        self.method = method
        self.random_seed = random_seed
        super().__init__(**kwargs)

    def setup_classifier(
        self, pooler_output_dim: int
    ) -> Callable[[torch.FloatTensor, torch.FloatTensor], torch.FloatTensor]:
        if self.method == "random":
            generator = torch.Generator(device=self.device)
            if self.random_seed is not None:
                generator.manual_seed(self.random_seed)

            def binary_classify_random(
                inputs: torch.FloatTensor,
                inputs_pair: torch.FloatTensor,
            ) -> torch.FloatTensor:
                """Randomly classifies pairs of inputs as similar or not similar."""
                # Generate random logits in the range [0, 1)
                logits = torch.rand(inputs.size(0), device=self.device, generator=generator)
                return logits

            return binary_classify_random
        elif self.method == "zero":

            def binary_classify_zero(
                inputs: torch.FloatTensor,
                inputs_pair: torch.FloatTensor,
            ) -> torch.FloatTensor:
                """Classifies all pairs of inputs as not similar (logit = 0)."""
                # Return a tensor of zeros with one entry per batch element
                logits = torch.zeros(inputs.size(0), device=self.device)
                return logits

            return binary_classify_zero
else:
raise ValueError(
f"Unknown method: {self.method}. Supported methods are 'random' and 'zero'."
)
    def setup_loss_fct(self) -> Callable:
        def loss_fct(logits: FloatTensor, labels: FloatTensor) -> FloatTensor:
            raise NotImplementedError(
                "The dummy model does not implement a loss function because it is not "
                "meant to be trained."
            )

        return loss_fct
    def get_pooled_output(self, model_inputs, pooler_inputs) -> torch.FloatTensor:
        # Skip the forward pass entirely and return a zero tensor with one entry per
        # batch element, so that the classifier can construct logits of the correct shape.
        bs = pooler_inputs["start_indices"].size(0)
        return torch.zeros(bs, device=self.device)
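
# A minimal usage sketch for the dummy baseline (hypothetical model name): the base
# transformer is still instantiated because the superclass requires one, but its
# forward pass is never invoked.
#   model = SequencePairSimilarityModelDummy(
#       model_name_or_path="bert-base-uncased", method="zero"
#   )
# Calling the model on a prepared batch then yields all-zero logits with one entry per
# sequence pair.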