# video-scoring/models/indonesian_scoring.py
# bearking58 — commit 3515f11: "feat: level 0 inference + improve model accuracy"
import torch
class IndonesianScoringModel(torch.nn.Module):
    """Score transcripts against competence levels with an NLI classifier.

    Each transcript (premise) is paired with every competence description in
    its own set (hypothesis). The classifier's entailment probability for each
    pair becomes a level logit. An extra "none" class in column 0 receives
    probability mass when no competence in the set is entailed strongly enough.
    """

    def __init__(
        self,
        model,
        tokenizer,
        device,
        *,
        entailment_idx: int = 0,
        temperature: float = 0.30,
        none_threshold: float = 0.30,
        none_sharpness: float = 12.0,
        max_length: int = 512,
    ):
        """
        Args:
            model: sequence-pair classifier. It is called as ``model(**inputs)``
                and must return an object with a ``logits`` attribute of
                shape (num_pairs, num_labels).
            tokenizer: called as ``tokenizer(premises, hypotheses, ...)``. The
                result must support ``.to(device)`` and ``**`` unpacking.
            device: device the model and the tokenized inputs are moved to.
            entailment_idx: column of ``logits`` that holds the entailment
                class. NOTE(review): the default of 0 should be confirmed
                against the checkpoint's ``config.label2id``; some MNLI
                checkpoints put entailment at index 2.
            temperature: divisor applied inside the level logit
                ``log(p / temperature)``.
            none_threshold: entailment level below which the "none" logit
                starts to grow.
            none_sharpness: slope of the softplus that produces the "none" logit.
            max_length: tokenizer truncation length for each pair.
        """
        super().__init__()
        self.device = device
        self.model = model.to(self.device)
        self.tokenizer = tokenizer
        self.entailment_idx = entailment_idx
        self.temperature = temperature
        self.none_threshold = none_threshold
        self.none_sharpness = none_sharpness
        self.max_length = max_length

    def forward(self, transcripts: list[str], competence_sets: list[list[str]]):
        """Return per-transcript probabilities over "none" plus each competence.

        Args:
            transcripts: N premise texts.
            competence_sets: N lists of competence (hypothesis) texts, one list
                per transcript. The lists may differ in length, and a list may
                be empty.

        Returns:
            A float32 tensor of shape (N, 1 + max_lc), where max_lc is the
            length of the longest competence set.
            - Column 0 is the "none" probability.
            - Column ``1 + j`` is the probability of competence ``j`` in that
              row's set.
            - Columns past a row's own set size are exactly 0.
            - Each row sums to 1.

        Raises:
            ValueError: if ``transcripts`` and ``competence_sets`` differ in
                length.
        """
        n = len(transcripts)
        if len(competence_sets) != n:
            raise ValueError(
                f"got {n} transcripts but {len(competence_sets)} competence sets"
            )

        lc_list = [len(cs) for cs in competence_sets]
        max_lc = max(lc_list, default=0)
        if max_lc == 0:
            # No hypotheses at all (this includes an empty batch). Every row
            # is entirely "none", which is what the softmax below would give
            # for a single finite logit. This also avoids calling the model
            # with no input.
            return torch.ones((n, 1), dtype=torch.float32, device=self.device)

        # mask[r, j] is True when transcript r has a j-th competence.
        lengths = torch.tensor(lc_list, dtype=torch.long, device=self.device)
        mask = torch.arange(max_lc, device=self.device).unsqueeze(0) < lengths.unsqueeze(1)

        # Only real (transcript, competence) pairs are scored, in row-major
        # order: transcript 0's competences first, then transcript 1's, etc.
        flat_t = [t for t, cs in zip(transcripts, competence_sets) for _ in cs]
        flat_c = [c for cs in competence_sets for c in cs]

        inputs = self.tokenizer(
            flat_t,
            flat_c,
            padding=True,
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        ).to(self.device)
        outputs = self.model(**inputs)

        # Upcast before the softmax so that half-precision models do not lose
        # precision in the probabilities.
        nli_probs = torch.softmax(outputs.logits.float(), dim=-1)
        entailment_prob = nli_probs[:, self.entailment_idx]

        # masked_scatter fills the True positions of the mask in row-major
        # order, which matches the flattening order above. Padded positions
        # stay at 0.
        entailment_matrix = torch.zeros(
            (n, max_lc), dtype=torch.float32, device=self.device
        ).masked_scatter(mask, entailment_prob)

        epsilon = 1e-12
        # NOTE(review): log(p / T) == log(p) - log(T) is a constant offset
        # for every level, not a temperature sharpening. If sharpening was
        # intended, it would be log(p) / T; kept as-is to preserve the
        # scoring behavior.
        level_logits = torch.log(entailment_matrix.clamp_min(epsilon) / self.temperature)
        level_logits = level_logits.masked_fill(~mask, float("-inf"))

        # Padded entries are 0, and probabilities are >= 0, so padding cannot
        # raise the row max. A row with no competences gets max 0, which
        # produces a strong "none" logit.
        max_entailment_prob = entailment_matrix.amax(dim=1)
        none_logit = torch.nn.functional.softplus(
            self.none_sharpness * (self.none_threshold - max_entailment_prob)
        )

        all_logits = torch.cat([none_logit.unsqueeze(1), level_logits], dim=1)
        # exp(-inf) == 0, so padded levels get exactly zero probability and
        # every row already sums to 1. No in-place zeroing or renormalization
        # is needed, which also keeps the result safe for autograd.
        return torch.softmax(all_logits, dim=1)