import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss

from huggingface_hub import PyTorchModelHubMixin
from transformers import AutoTokenizer, RobertaModel
from transformers.modeling_outputs import SequenceClassifierOutput

class SentenceBERTClassifier(nn.Module, PyTorchModelHubMixin):
    """Sentence-transformer (RoBERTa) encoder with a linear classification head."""

    def __init__(self, model_name="sentence-transformers/all-distilroberta-v1", num_labels=8):
        super().__init__()
        self.sbert = RobertaModel.from_pretrained(model_name)
        self.config = self.sbert.config
        self.config.num_labels = num_labels
        self.dropout = nn.Dropout(0.05)
        self.config.classifier_dropout = 0.05
        self.classifier = nn.Linear(self.config.hidden_size, self.config.num_labels)

    def forward(self, input_ids, attention_mask):
        outputs = self.sbert(input_ids=input_ids, attention_mask=attention_mask)
        # The pooled [CLS] representation serves as the sentence embedding.
        pooled_output = outputs.pooler_output
        dropout_output = self.dropout(pooled_output)
        logits = self.classifier(dropout_output)

        return SequenceClassifierOutput(
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
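
# Hedged usage sketch (added for illustration, not part of the original model
# code): one plausible way to drive the classifier above with the AutoTokenizer
# and CrossEntropyLoss imported at the top. The texts and labels are invented.
def _demo_sentence_classifier():
    tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-distilroberta-v1")
    model = SentenceBERTClassifier()
    batch = tokenizer(["a sample sentence", "another sample"], padding=True, return_tensors="pt")
    output = model(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"])
    labels = torch.tensor([0, 3])  # hypothetical gold labels in [0, num_labels)
    loss = CrossEntropyLoss()(output.logits, labels)
    return output.logits.shape, loss  # logits: (batch_size, num_labels)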


class DenseBlock(nn.Module):
    """Linear -> BatchNorm -> ReLU -> Dropout building block."""

    def __init__(self, input_size, output_size, dropout_rate):
        super().__init__()
        self.linear = nn.Linear(input_size, output_size)
        self.batch_norm = nn.BatchNorm1d(output_size)
        self.activation = nn.ReLU()
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x):
        output = self.linear(x)
        output = self.batch_norm(output)
        output = self.activation(output)
        output = self.dropout(output)
        return output

class FeedForwardExpert(nn.Module):
    """Two dense blocks followed by a linear output layer producing logits."""

    def __init__(self, dropout_rate, num_labels=8):
        super().__init__()

        # 768 matches the hidden size of the SBERT encoder above.
        self.block_1 = DenseBlock(768, 400, dropout_rate)
        self.block_2 = DenseBlock(400, 200, dropout_rate)
        self.final_layer = nn.Linear(200, num_labels)

        self.initialize_weights()

    def forward(self, x):
        output = self.block_1(x)
        output = self.block_2(output)
        output = self.final_layer(output)
        return output

    def initialize_weights(self):
        # Xavier-uniform initialization for every linear layer, zero bias.
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
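
# Hedged sanity check (illustrative): a single expert maps a 768-d pooled
# embedding to num_labels logits. Because DenseBlock uses BatchNorm1d, a batch
# of one fails in training mode; call .eval() for single-sample inference.
def _demo_expert():
    expert = FeedForwardExpert(dropout_rate=0.1)
    expert.eval()  # use BatchNorm running stats so a batch of one works
    dummy = torch.randn(1, 768)  # fabricated stand-in for an SBERT embedding
    return expert(dummy).shape  # -> torch.Size([1, 8])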


class MoEClassifier(nn.Module):
    """Mixture-of-experts head: a softmax gate blends per-expert logits."""

    def __init__(self, num_experts, dropout_rate=0.1, gate_hidden_size=128):
        super().__init__()
        self.dropout = dropout_rate
        self.num_experts = num_experts
        self.gate_hidden_size = gate_hidden_size

        # A list of feedforward experts
        self.experts = nn.ModuleList([FeedForwardExpert(self.dropout) for _ in range(self.num_experts)])

        # Gating network: 768-d embedding -> one weight per expert
        self.gate_fc1 = nn.Linear(768, self.gate_hidden_size)
        self.gate_fc2 = nn.Linear(self.gate_hidden_size, self.num_experts)

    def forward(self, x):
        # Gating weights: (batch, num_experts, 1)
        gate_hidden = F.relu(self.gate_fc1(x))
        weights = F.softmax(self.gate_fc2(gate_hidden), dim=1).unsqueeze(2)

        # Stack expert logits: (batch, num_labels, num_experts)
        outputs = torch.stack([expert(x) for expert in self.experts], dim=2)

        # Weighted sum via batch matrix multiplication: (batch, num_labels)
        weighted_outputs = torch.bmm(outputs, weights).squeeze(2)

        return weighted_outputs
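
# Hedged end-to-end sketch (illustrative): one plausible wiring of the pieces
# above, feeding the SBERT pooled embedding into the mixture-of-experts head.
# The gate soft-weights each expert's logits per example; texts are invented.
def _demo_moe():
    tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-distilroberta-v1")
    encoder = RobertaModel.from_pretrained("sentence-transformers/all-distilroberta-v1")
    moe = MoEClassifier(num_experts=4)
    moe.eval()  # BatchNorm1d inside each expert needs batch > 1 in training mode
    batch = tokenizer(["one sentence", "two sentences"], padding=True, return_tensors="pt")
    embeddings = encoder(**batch).pooler_output  # (batch_size, 768)
    logits = moe(embeddings)                     # (batch_size, num_labels)
    return logits.shape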