import torch


class IndonesianScoringModel(torch.nn.Module):
    """Scores transcripts against per-transcript competence sets via NLI
    entailment, returning a distribution over "none" plus each competence."""

    def __init__(self, model, tokenizer, device):
        super().__init__()
        self.device = device
        self.model = model.to(self.device)
        self.tokenizer = tokenizer

    def forward(self, transcripts: list[str], competence_sets: list[list[str]]):
        N = len(transcripts)
        lc_list = [len(cs) for cs in competence_sets]
        max_lc = max(lc_list)

        # Flatten (transcript, competence) pairs row by row, padding each row
        # to max_lc so scores can be reshaped to an (N, max_lc) matrix.
        # Padding must be interleaved per row: appending all real pairs first
        # and all pads afterwards would scramble the rows after .view().
        flat_t, flat_c = [], []
        for t, cs in zip(transcripts, competence_sets):
            flat_t.extend([t] * len(cs))
            flat_c.extend(cs)
            pad = max_lc - len(cs)
            if pad > 0:
                flat_t.extend([t] * pad)
                flat_c.extend(["__PAD__"] * pad)

        # Tokenize premise/hypothesis pairs.
        inputs = self.tokenizer(
            flat_t,
            flat_c,
            padding=True,
            truncation=True,
            max_length=512,
            return_tensors="pt",
        ).to(self.device)
        outputs = self.model(**inputs)

        # Assumes the NLI head puts "entailment" at index 0; verify against
        # model.config.id2label for the checkpoint in use.
        ENTAILMENT_IDX = 0
        nli_probs = torch.nn.functional.softmax(outputs.logits, dim=1)
        entailment_prob = nli_probs[:, ENTAILMENT_IDX]
        entailment_matrix = entailment_prob.view(N, max_lc)

        # Boolean mask of real (non-padded) competence slots per row.
        mask = torch.zeros((N, max_lc), dtype=torch.bool, device=self.device)
        for r, k in enumerate(lc_list):
            if k > 0:
                mask[r, :k] = True

        # Level logits: log entailment probability. Dividing by T inside the
        # log adds a constant log(1/T) to every level logit, biasing the
        # levels relative to the separately computed "none" logit below.
        epsilon = 1e-12
        T = 0.30
        level_logits = torch.log(entailment_matrix.clamp_min(epsilon) / T)
        level_logits = level_logits.masked_fill(~mask, float("-inf"))

        # "None" logit: large when the best entailment probability falls
        # below the threshold tau, near zero once it exceeds tau.
        tau = 0.30
        alpha = 12.0
        max_entailment_prob, _ = entailment_matrix.max(dim=1)
        none_logit = torch.nn.functional.softplus(alpha * (tau - max_entailment_prob))

        # Column 0 is the "none" option; columns 1..max_lc are the levels.
        all_logits = torch.zeros((N, 1 + max_lc), device=self.device)
        all_logits[:, 0] = none_logit
        all_logits[:, 1:] = level_logits
        probs = torch.softmax(all_logits, dim=1)

        # Padded slots already get zero probability from the -inf logits;
        # zero them out-of-place (an in-place write into the softmax output
        # would break autograd) and renormalize each row to sum to 1.
        keep = torch.cat(
            [torch.ones((N, 1), dtype=torch.bool, device=self.device), mask],
            dim=1,
        )
        probs = probs.masked_fill(~keep, 0.0)
        row_sums = probs.sum(dim=1, keepdim=True).clamp_min(1e-12)
        probs = probs / row_sums
        return probs
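
if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module. Assumptions: a
    # Hugging Face NLI checkpoint is used, and its label order puts
    # "entailment" at index 0 (check model.config.id2label before relying on
    # ENTAILMENT_IDX above). The checkpoint name and example texts below are
    # illustrative choices, not mandated by the model class itself.
    from transformers import AutoModelForSequenceClassification, AutoTokenizer

    checkpoint = "MoritzLaurer/mDeBERTa-v3-base-mnli-xnli"  # assumed choice
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    model = AutoModelForSequenceClassification.from_pretrained(checkpoint)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    scorer = IndonesianScoringModel(model, tokenizer, device)

    transcripts = [
        "Kandidat menjelaskan pengalaman memimpin tim selama tiga tahun.",
        "Kandidat menyebutkan kemampuan analisis data.",
    ]
    competence_sets = [
        [
            "Kandidat memiliki pengalaman kepemimpinan.",
            "Kandidat menguasai bahasa asing.",
        ],
        ["Kandidat mampu menganalisis data."],
    ]

    with torch.no_grad():
        probs = scorer(transcripts, competence_sets)
    # probs[:, 0] is the "none of the competences" option; probs[:, 1:] align
    # with each row's competence list (padded columns are zeroed out).
    print(probs)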