import torch


class IndonesianScoringModel(torch.nn.Module):
    """Scores transcripts against per-transcript competence descriptions by
    feeding (transcript, competence) sentence pairs to a pretrained NLI model
    and converting the entailment probabilities into a level distribution."""

    def __init__(self, model, tokenizer, device):
        super().__init__()
        self.device = device
        self.model = model.to(self.device)
        self.tokenizer = tokenizer

    def forward(self, transcripts: list[str], competence_sets: list[list[str]]):
        # For each transcript, score its competence descriptions with the NLI
        # model and return a per-transcript probability distribution over
        # "none of the levels" (column 0) plus each competence level.
        N = len(transcripts)
        lc_list = [len(cs) for cs in competence_sets]  # levels per transcript
        max_lc = max(lc_list)

        # Flatten into one (transcript, competence) pair per slot, padding each
        # transcript up to max_lc pairs. Padding must be interleaved per
        # transcript (not appended after all real pairs), otherwise the
        # view(N, max_lc) reshape below mixes scores from different
        # transcripts on the same row.
        flat_t, flat_c = [], []
        for t, cs in zip(transcripts, competence_sets):
            flat_t.extend([t] * len(cs))
            flat_c.extend(cs)
            pad = max_lc - len(cs)
            if pad > 0:
                flat_t.extend([t] * pad)
                flat_c.extend(["__PAD__"] * pad)

        # Tokenize
        inputs = self.tokenizer(
            flat_t,
            flat_c,
            padding=True,
            truncation=True,
            max_length=512,
            return_tensors="pt",
        ).to(self.device)

        outputs = self.model(**inputs)

        # Index of the "entailment" class in the model's logits; this must
        # match the label order of the NLI checkpoint being used.
        ENTAILMENT_IDX = 0

        nli_probs = torch.nn.functional.softmax(outputs.logits, dim=1)
        entailment_prob = nli_probs[:, ENTAILMENT_IDX]

        # Reshape the flat batch back into one row of max_lc scores per transcript.
        entailment_matrix = entailment_prob.view(N, max_lc)

        # True for real competence slots, False for "__PAD__" slots.
        mask = torch.zeros((N, max_lc), dtype=torch.bool, device=self.device)
        for r, k in enumerate(lc_list):
            if k > 0:
                mask[r, :k] = True

        # Turn entailment probabilities into logits, log(p / T); the constant T
        # rescales the level logits relative to the "none" logit added below.
        # Padded slots get -inf so they receive zero probability mass.
        epsilon = 1e-12
        T = 0.30
        level_logits = torch.log(entailment_matrix.clamp_min(epsilon) / T)
        level_logits = level_logits.masked_fill(~mask, float("-inf"))

        # Logit for the "none of the levels" option: softplus grows as the best
        # entailment probability falls below the threshold tau, with alpha
        # controlling how sharp that transition is.
        tau = 0.30
        alpha = 12.0
        max_entailment_prob, _ = entailment_matrix.max(dim=1)
        none_logit = torch.nn.functional.softplus(alpha * (tau - max_entailment_prob))

        # Column 0 is "none"; columns 1..max_lc are the competence levels.
        all_logits = torch.zeros((N, 1 + max_lc), device=self.device)
        all_logits[:, 0] = none_logit
        all_logits[:, 1:] = level_logits

        probs = torch.softmax(all_logits, dim=1)

        # Defensive cleanup: padded slots are already ~0 thanks to the -inf
        # logits; zero them explicitly and renormalise each row to sum to 1.
        probs[:, 1:][~mask] = 0.0
        row_sums = probs.sum(dim=1, keepdim=True).clamp_min(1e-12)
        probs = probs / row_sums

        return probs
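

# --- Usage sketch (illustrative only) ---------------------------------------
# A minimal example of how this module might be driven, assuming an NLI
# checkpoint loaded via the Hugging Face `transformers` library. The model
# name, transcripts, and competence descriptions below are hypothetical
# placeholders; ENTAILMENT_IDX above must be adjusted to the label order of
# whichever checkpoint is actually used.
if __name__ == "__main__":
    from transformers import AutoModelForSequenceClassification, AutoTokenizer

    model_name = "some-org/indonesian-nli-model"  # hypothetical checkpoint name
    device = "cuda" if torch.cuda.is_available() else "cpu"

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    nli_model = AutoModelForSequenceClassification.from_pretrained(model_name)

    scorer = IndonesianScoringModel(nli_model, tokenizer, device)

    # Placeholder inputs: two transcripts with differently sized competence sets.
    transcripts = [
        "Saya memimpin tim kecil dan menyusun rencana kerja mingguan.",
        "Saya membantu pelanggan menyelesaikan keluhan mereka.",
    ]
    competence_sets = [
        ["Kepemimpinan dasar", "Perencanaan kerja", "Manajemen konflik"],
        ["Pelayanan pelanggan", "Komunikasi efektif"],
    ]

    with torch.no_grad():
        probs = scorer(transcripts, competence_sets)

    # probs has shape (2, 1 + max_lc); column 0 is "none of the levels".
    print(probs)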