import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModel

class DeBERTaLSTMClassifier(nn.Module):
    def __init__(self, hidden_dim=128, num_labels=2):
        super().__init__()

        self.deberta = AutoModel.from_pretrained("microsoft/deberta-base")
        for param in self.deberta.parameters():
            param.requires_grad = False  # freeze DeBERTa; fine-tuning the full encoder is beyond our compute budget

        self.lstm = nn.LSTM(
            input_size=self.deberta.config.hidden_size,
            hidden_size=hidden_dim,
            batch_first=True,
            bidirectional=True
        )

        self.fc = nn.Linear(hidden_dim * 2, num_labels)
        
        # Attention layer for scoring per-token importance
        self.attention = nn.Linear(hidden_dim * 2, 1)

    def forward(self, input_ids, attention_mask, return_attention=False):
        with torch.no_grad():
            outputs = self.deberta(input_ids=input_ids, attention_mask=attention_mask, output_attentions=True)

        lstm_out, _ = self.lstm(outputs.last_hidden_state)  # shape: [batch, seq_len, hidden*2]
        
        if return_attention:
            # Compute attention weights for each token
            attention_weights = self.attention(lstm_out)  # [batch, seq_len, 1]
            attention_weights = F.softmax(attention_weights.squeeze(-1), dim=-1)  # [batch, seq_len]
            
            # Apply attention mask
            attention_weights = attention_weights * attention_mask.float()
            attention_weights = attention_weights / (attention_weights.sum(dim=-1, keepdim=True) + 1e-8)
            
            # Weighted sum of LSTM outputs
            attended_output = torch.sum(lstm_out * attention_weights.unsqueeze(-1), dim=1)
            logits = self.fc(attended_output)
            
            return logits, attention_weights, outputs.attentions
        else:
            # Use the hidden state of the last non-padding token; lstm_out[:, -1, :]
            # would land on a padding position in padded batches.
            last_idx = attention_mask.sum(dim=1).long() - 1                      # [batch]
            batch_idx = torch.arange(lstm_out.size(0), device=lstm_out.device)
            final_hidden = lstm_out[batch_idx, last_idx]                         # [batch, hidden*2]
            logits = self.fc(final_hidden)
            return logits
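

# --- Usage sketch (not part of the original module) ---------------------------
# Minimal example of a forward pass, assuming the matching "microsoft/deberta-base"
# tokenizer. The sample sentences and hyperparameters are illustrative only.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("microsoft/deberta-base")
    model = DeBERTaLSTMClassifier(hidden_dim=128, num_labels=2)
    model.eval()

    texts = ["An example sentence.", "Another, slightly longer example sentence."]
    batch = tokenizer(texts, padding=True, truncation=True, return_tensors="pt")

    with torch.no_grad():
        logits, token_weights, deberta_attentions = model(
            batch["input_ids"], batch["attention_mask"], return_attention=True
        )

    print(logits.shape)         # [2, num_labels]
    print(token_weights.shape)  # [2, seq_len]; each row sums to ~1 over non-pad tokens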