import torch
from transformers import AutoTokenizer, AutoModel, AutoModelForSequenceClassification


class EnglishScoringModel(torch.nn.Module):
    def __init__(self, model, tokenizer, device):
        super().__init__()
        self.device = device
        self.model = model.to(self.device)
        self._tokenizer = tokenizer

    @staticmethod
    def load(
        model_path: str, type: str, state_dict_path: str = None, device="cpu"
    ) -> "EnglishScoringModel":
        """
        Load the model from the given path and return the model instance.

        Args:
            model_path (str): The path to the model.
            type (str): The type of the model. It should be either 'biencoder' or 'crossencoder'.
            state_dict_path (str): The path to the state dict. Default is None.
            device (str): The device to use. Default is 'cpu'.

        Returns:
            EnglishScoringModel: The model instance.
        """
        tokenizer = AutoTokenizer.from_pretrained(model_path)
        if type == "biencoder":
            model = AutoModel.from_pretrained(model_path)
            competence_model = BiEncoder(model, tokenizer, device)
        elif type == "crossencoder":
            model = AutoModelForSequenceClassification.from_pretrained(model_path)
            competence_model = CrossEncoder(model, tokenizer, device)
        else:
            raise NotImplementedError(
                "Model type is only implemented for biencoder and crossencoder"
            )
        if state_dict_path:
            competence_model.load_state_dict(torch.load(state_dict_path), strict=False)
        return competence_model

    def save_state_dict(self, state_dict_path: str) -> None:
        torch.save(self.state_dict(), state_dict_path)

    def tokenizer(self, *args, **kwargs):
        """
        Tokenize the given arguments and return the tokenized tensors.
        Default options are padding=True, truncation=True, and return_tensors='pt'.
        """
        kwargs.setdefault("padding", True)
        kwargs.setdefault("truncation", True)
        kwargs.setdefault("return_tensors", "pt")
        return self._tokenizer(*args, **kwargs).to(self.device)

    def forward(self, *args, type: str = "set", **kwargs) -> torch.Tensor:
        """
        Forward pass of the model. Forward type should be either 'single' or 'set'.

        Args:
            type (str): The type of the forward pass. Default is 'set'.
        """
        if type == "single":
            return self.forward_single(*args, **kwargs)
        elif type == "set":
            return self.forward_set(*args, **kwargs)
        else:
            raise ValueError("Forward type should be either 'single' or 'set'.")

    def forward_single(
        self, transcripts: list[str], competences: list[str], **kwargs
    ) -> torch.Tensor:
        """
        Forward pass of the model for each transcript and competence pair.

        Args:
            transcripts (list[str]): The list of transcripts.
            competences (list[str]): The list of competences.

        Returns:
            torch.Tensor: The predicted probabilities from each pair.
        """
        assert len(transcripts) == len(competences)
        raise NotImplementedError

    def forward_set(
        self, transcripts: list[str], competence_sets: list[list[str]], **kwargs
    ) -> torch.Tensor:
        """
        Forward pass of the model for each transcript and set of competences.

        Args:
            transcripts (list[str]): The list of transcripts.
            competence_sets (list[list[str]]): The list of sets of competences.

        Returns:
            torch.Tensor: The predicted probabilities from each transcript across
                the set of competences.
""" assert len(transcripts) == len(competence_sets) device = self.device lc_list = [len(competences) for competences in competence_sets] max_lc = max(lc_list) flat_t = [t for i, t in enumerate(transcripts) for _ in range(lc_list[i])] flat_c = [c for cs in competence_sets for c in cs] sims = self(flat_t, flat_c, type="single", **kwargs) mask = torch.tensor( [[1] * lc + [0] * (max_lc - lc) for lc in lc_list], device=device, dtype=torch.bool, ) padded = torch.full( (len(lc_list), max_lc), fill_value=float("-inf"), device=device ) idx = 0 for r, lc in enumerate(lc_list): padded[r, :lc] = sims[idx : idx + lc] idx += lc T = 0.30 tau = 0.30 alpha = 12.0 level_logits = padded / T with torch.no_grad(): sim_padded = padded.clone() sim_padded[~mask] = float("-inf") max_sim, _ = sim_padded.max(dim=1) max_sim[max_sim == float("-inf")] = -1.0 none_logit = torch.nn.functional.softplus(alpha * (tau - max_sim)) all_logits = torch.zeros((len(lc_list), 1 + max_lc), device=device) all_logits[:, 0] = none_logit all_logits[:, 1:] = level_logits probs = torch.softmax(all_logits, dim=1) probs[:, 1:][~mask] = 0.0 row_sums = probs.sum(dim=1, keepdim=True).clamp_min(1e-12) probs = probs / row_sums return probs class BiEncoder(EnglishScoringModel): def __init__(self, model, tokenizer, device="cpu"): super().__init__(model, tokenizer, device) def forward_single( self, transcripts: list[str], competences: list[str], tokenizer_padding=True ) -> torch.Tensor: assert len(transcripts) == len(competences) features_t = self.tokenizer(transcripts, padding=tokenizer_padding) features_c = self.tokenizer(competences, padding=tokenizer_padding) embeddings_t = self.model(**features_t) embeddings_c = self.model(**features_c) embeddings_t = self.pooling(embeddings_t, features_t["attention_mask"]) embeddings_c = self.pooling(embeddings_c, features_c["attention_mask"]) prob = torch.nn.functional.cosine_similarity(embeddings_t, embeddings_c, dim=1) prob = torch.clamp(prob, min=1e-20) return prob @staticmethod def pooling(model_output, attention_mask: torch.Tensor) -> torch.Tensor: """ Pool the model output using the attention mask with normalized mean pooling. Args: model_output (torch.Tensor): The model output tensor. attention_mask (torch.Tensor): The attention mask tensor. Returns: torch.Tensor: The pooled embeddings. """ token_embeddings = model_output[0] input_mask_expanded = ( attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float() ) pooled_embeddings = torch.sum( token_embeddings * input_mask_expanded, 1 ) / torch.clamp(input_mask_expanded.sum(1), min=1e-9) return torch.nn.functional.normalize(pooled_embeddings, p=2, dim=1) class CrossEncoder(EnglishScoringModel): def __init__(self, model, tokenizer, device="cpu"): super().__init__(model, tokenizer, device) def forward_single( self, transcripts: list[str], competences: list[str], tokenizer_padding=True ) -> torch.Tensor: assert len(transcripts) == len(competences) features = self.tokenizer(transcripts, competences, padding=tokenizer_padding) logits = self.model(**features).logits prob = torch.nn.functional.softmax(logits, dim=1) prob = prob[:, 1] return prob