zArabi commited on
Commit
28bc44f
·
1 Parent(s): ab3bc41

Create model.py

Browse files
Files changed (1) hide show
  1. model.py +19 -0
model.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import BertModel
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ class SentimentModel(nn.Module):
6
+
7
+ def __init__(self, config):
8
+ super(SentimentModel, self).__init__()
9
+ self.bert = BertModel.from_pretrained(modelName, return_dict=False)
10
+ self.dropout = nn.Dropout(0.3)
11
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
12
+
13
+ def forward(self, input_ids, attention_mask):
14
+ _, pooled_output = self.bert(
15
+ input_ids=input_ids,
16
+ attention_mask=attention_mask)
17
+ pooled_output = self.dropout(pooled_output)
18
+ logits = self.classifier(pooled_output)
19
+ return logits