Spaces:
Runtime error
Runtime error
Commit
·
6657904
1
Parent(s):
9c9b15b
update
Browse files- app/model.py +4 -3
app/model.py
CHANGED
@@ -25,12 +25,14 @@ class IndoBERTBiLSTM(PreTrainedModel):
|
|
25 |
def __init__(self, bert_config):
|
26 |
super().__init__(bert_config)
|
27 |
self.output_dim = OUTPUT_DIM
|
|
|
28 |
self.hidden_dim = HIDDEN_DIM
|
29 |
self.bidirectional = BIDIRECTIONAL
|
30 |
|
31 |
self.bert = BertModel.from_pretrained(bert_path)
|
32 |
self.lstm = nn.LSTM(input_size=self.bert.config.hidden_size,
|
33 |
hidden_size=self.hidden_dim,
|
|
|
34 |
bidirectional=self.bidirectional,
|
35 |
batch_first=True)
|
36 |
self.dropout = nn.Dropout(DROPOUT)
|
@@ -39,11 +41,11 @@ class IndoBERTBiLSTM(PreTrainedModel):
|
|
39 |
def forward(self, input_ids, attention_mask):
|
40 |
|
41 |
hidden = self.init_hidden(input_ids.shape[0])
|
42 |
-
# print("hidden : ", type(hidden))
|
43 |
output = self.bert(input_ids=input_ids, attention_mask=attention_mask)
|
44 |
sequence_output = output.last_hidden_state
|
45 |
|
46 |
lstm_output, (hidden_last, cn_last) = self.lstm(sequence_output, hidden)
|
|
|
47 |
hidden_last_L=hidden_last[-2]
|
48 |
hidden_last_R=hidden_last[-1]
|
49 |
hidden_last_out=torch.cat([hidden_last_L,hidden_last_R],dim=-1) #[16, 1536]
|
@@ -72,5 +74,4 @@ class IndoBERTBiLSTM(PreTrainedModel):
|
|
72 |
weight.new(self.n_layers*number, batch_size, self.hidden_dim).zero_().float()
|
73 |
)
|
74 |
|
75 |
-
return hidden
|
76 |
-
|
|
|
25 |
def __init__(self, bert_config):
|
26 |
super().__init__(bert_config)
|
27 |
self.output_dim = OUTPUT_DIM
|
28 |
+
self.n_layers = 1
|
29 |
self.hidden_dim = HIDDEN_DIM
|
30 |
self.bidirectional = BIDIRECTIONAL
|
31 |
|
32 |
self.bert = BertModel.from_pretrained(bert_path)
|
33 |
self.lstm = nn.LSTM(input_size=self.bert.config.hidden_size,
|
34 |
hidden_size=self.hidden_dim,
|
35 |
+
num_layers=self.n_layers,
|
36 |
bidirectional=self.bidirectional,
|
37 |
batch_first=True)
|
38 |
self.dropout = nn.Dropout(DROPOUT)
|
|
|
41 |
def forward(self, input_ids, attention_mask):
|
42 |
|
43 |
hidden = self.init_hidden(input_ids.shape[0])
|
|
|
44 |
output = self.bert(input_ids=input_ids, attention_mask=attention_mask)
|
45 |
sequence_output = output.last_hidden_state
|
46 |
|
47 |
lstm_output, (hidden_last, cn_last) = self.lstm(sequence_output, hidden)
|
48 |
+
|
49 |
hidden_last_L=hidden_last[-2]
|
50 |
hidden_last_R=hidden_last[-1]
|
51 |
hidden_last_out=torch.cat([hidden_last_L,hidden_last_R],dim=-1) #[16, 1536]
|
|
|
74 |
weight.new(self.n_layers*number, batch_size, self.hidden_dim).zero_().float()
|
75 |
)
|
76 |
|
77 |
+
return hidden
|
|