kadabengaran committed · Commit 6657904 · 1 Parent(s): 9c9b15b
Files changed (1)
  1. app/model.py +4 -3
app/model.py CHANGED
@@ -25,12 +25,14 @@ class IndoBERTBiLSTM(PreTrainedModel):
     def __init__(self, bert_config):
         super().__init__(bert_config)
         self.output_dim = OUTPUT_DIM
+        self.n_layers = 1
         self.hidden_dim = HIDDEN_DIM
         self.bidirectional = BIDIRECTIONAL

         self.bert = BertModel.from_pretrained(bert_path)
         self.lstm = nn.LSTM(input_size=self.bert.config.hidden_size,
                             hidden_size=self.hidden_dim,
+                            num_layers=self.n_layers,
                             bidirectional=self.bidirectional,
                             batch_first=True)
         self.dropout = nn.Dropout(DROPOUT)
@@ -39,11 +41,11 @@ class IndoBERTBiLSTM(PreTrainedModel):
     def forward(self, input_ids, attention_mask):

         hidden = self.init_hidden(input_ids.shape[0])
-        # print("hidden : ", type(hidden))
         output = self.bert(input_ids=input_ids, attention_mask=attention_mask)
         sequence_output = output.last_hidden_state

         lstm_output, (hidden_last, cn_last) = self.lstm(sequence_output, hidden)
+
         hidden_last_L=hidden_last[-2]
         hidden_last_R=hidden_last[-1]
         hidden_last_out=torch.cat([hidden_last_L,hidden_last_R],dim=-1) #[16, 1536]
@@ -72,5 +74,4 @@ class IndoBERTBiLSTM(PreTrainedModel):
             weight.new(self.n_layers*number, batch_size, self.hidden_dim).zero_().float()
         )

-        return hidden
-
+        return hidden
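For context, here is a minimal standalone sketch (not the repo's code) of the shape contract this change keeps consistent: nn.LSTM expects the initial (h0, c0) tensors to have shape (num_layers * num_directions, batch_size, hidden_size), and init_hidden builds them from self.n_layers, which the old __init__ never defined. The hidden_dim=768 and batch_size=16 values below are assumptions inferred from the "#[16, 1536]" comment in the diff; seq_len is arbitrary and for illustration only.

import torch
import torch.nn as nn

hidden_dim = 768       # assumed from the "[16, 1536]" comment (2 * 768 = 1536)
n_layers = 1           # matches the value added in this commit
batch_size = 16        # assumed from the "[16, 1536]" comment
seq_len = 128          # arbitrary, illustration only
num_directions = 2     # bidirectional=True

lstm = nn.LSTM(input_size=768, hidden_size=hidden_dim,
               num_layers=n_layers, bidirectional=True, batch_first=True)

# Zero-initialised hidden/cell state, mirroring init_hidden in the diff:
# shape must be (num_layers * num_directions, batch_size, hidden_size).
h0 = torch.zeros(n_layers * num_directions, batch_size, hidden_dim)
c0 = torch.zeros(n_layers * num_directions, batch_size, hidden_dim)

x = torch.randn(batch_size, seq_len, 768)   # stand-in for BERT's last_hidden_state
out, (hn, cn) = lstm(x, (h0, c0))

# hn[-2] / hn[-1] are the forward / backward states of the last layer, so
# concatenating them yields (batch_size, 2 * hidden_dim), i.e. [16, 1536].
print(torch.cat([hn[-2], hn[-1]], dim=-1).shape)   # torch.Size([16, 1536])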