# Spaces: Sleeping — Hugging Face Spaces page-status residue from extraction; not code.
import torch
class IndonesianScoringModel(torch.nn.Module):
    """NLI-based competence scorer.

    For each transcript, every competence statement in its (possibly
    ragged) competence set is paired with the transcript and run through a
    cross-encoder NLI model. Per-level entailment probabilities become
    level logits, an explicit "none of the levels" option is prepended,
    and a softmax over all options yields one distribution per transcript.
    """

    # Index of the entailment class in the NLI head's output logits.
    # NOTE(review): label order is checkpoint-specific — confirm against
    # the model config's id2label mapping.
    ENTAILMENT_IDX = 0
    # Temperature dividing entailment probabilities before the log.
    LEVEL_TEMPERATURE = 0.30
    # "None" gate: entailment threshold and gate sharpness.
    NONE_TAU = 0.30
    NONE_ALPHA = 12.0
    # Hypothesis text used to pad ragged competence sets to a rectangle.
    PAD_HYPOTHESIS = "__PAD__"
    _EPS = 1e-12  # guards log(0) on clamped entailment probabilities

    def __init__(self, model, tokenizer, device):
        super().__init__()
        self.device = device
        self.model = model.to(self.device)
        self.tokenizer = tokenizer

    def forward(self, transcripts: list[str], competence_sets: list[list[str]]) -> torch.Tensor:
        """Score each transcript against its own competence statements.

        Args:
            transcripts: N premise texts.
            competence_sets: N lists of hypothesis texts; lengths may differ.

        Returns:
            Tensor of shape (N, 1 + max_lc). Column 0 is P(none); columns
            1..max_lc are per-level probabilities, exactly 0 at padded
            positions. Each row sums to 1.
        """
        N = len(transcripts)
        if N == 0:
            # Nothing to score; a single "none" column keeps the shape contract.
            return torch.zeros((0, 1), device=self.device)

        lc_list = [len(cs) for cs in competence_sets]
        max_lc = max(lc_list)
        if max_lc == 0:
            # No hypotheses anywhere: every transcript is "none" with certainty.
            return torch.ones((N, 1), device=self.device)

        # Flatten to a rectangle ROW BY ROW so that view(N, max_lc) below
        # maps flat index r*max_lc + j to (transcript r, hypothesis j).
        # Appending all pads after all real pairs would scramble the rows
        # whenever an earlier transcript has fewer than max_lc statements.
        flat_t: list[str] = []
        flat_c: list[str] = []
        for t, cs in zip(transcripts, competence_sets):
            flat_t.extend([t] * max_lc)
            flat_c.extend(list(cs) + [self.PAD_HYPOTHESIS] * (max_lc - len(cs)))

        inputs = self.tokenizer(
            flat_t,
            flat_c,
            padding=True,
            truncation=True,
            max_length=512,
            return_tensors="pt",
        ).to(self.device)
        outputs = self.model(**inputs)

        nli_probs = torch.nn.functional.softmax(outputs.logits, dim=1)
        entailment_matrix = nli_probs[:, self.ENTAILMENT_IDX].view(N, max_lc)

        # mask[r, j] is True exactly for real (non-pad) hypotheses.
        lengths = torch.as_tensor(lc_list, device=self.device)
        mask = torch.arange(max_lc, device=self.device).unsqueeze(0) < lengths.unsqueeze(1)

        # Per-level logits; padded slots get -inf so softmax assigns them
        # exactly zero probability — no post-hoc zeroing needed.
        level_logits = torch.log(entailment_matrix.clamp_min(self._EPS) / self.LEVEL_TEMPERATURE)
        level_logits = level_logits.masked_fill(~mask, float("-inf"))

        # "None" logit grows as the best REAL entailment falls below tau.
        # Pads are masked to 0 first so their scores cannot leak into the gate.
        best_real = entailment_matrix.masked_fill(~mask, 0.0).max(dim=1).values
        none_logit = torch.nn.functional.softplus(self.NONE_ALPHA * (self.NONE_TAU - best_real))

        all_logits = torch.cat([none_logit.unsqueeze(1), level_logits], dim=1)
        return torch.softmax(all_logits, dim=1)