ThanhDT127 commited on
Commit
c017e96
·
1 Parent(s): eb16c9f

update main

Browse files
Files changed (1) hide show
  1. main.py +7 -7
main.py CHANGED
@@ -50,7 +50,7 @@ class TextInput(BaseModel):
50
  text: str
51
 
52
  class BertBiLSTMClassifier(nn.Module):
53
- def __init__(self, bert_model_name, num_emotion_classes, binary_cols, lstm_hidden_size=256, dropout=0.3):
54
  super().__init__()
55
  self.bert = AutoModel.from_pretrained(bert_model_name)
56
  self.lstm = nn.LSTM(
@@ -93,7 +93,7 @@ model = BertBiLSTMClassifier(
93
  bert_model_name="vinai/phobert-base",
94
  num_emotion_classes=3,
95
  binary_cols=binary_cols,
96
- lstm_hidden_size=256
97
  ).to(device)
98
 
99
  # Load model state dict
@@ -101,11 +101,11 @@ model.load_state_dict(model_state_dict)
101
  model.eval()
102
 
103
  threshold_dict = {
104
- 'sản phẩm': 0.6,
105
- 'giá cả': 0.4,
106
- 'vận chuyển': 0.45,
107
- 'thái độ và dịch vụ khách hàng': 0.35,
108
- 'khác': 0.4
109
  }
110
 
111
  def predict(text: str):
 
50
  text: str
51
 
52
  class BertBiLSTMClassifier(nn.Module):
53
+ def __init__(self, bert_model_name, num_emotion_classes, binary_cols, lstm_hidden_size=128, dropout=0.4):
54
  super().__init__()
55
  self.bert = AutoModel.from_pretrained(bert_model_name)
56
  self.lstm = nn.LSTM(
 
93
  bert_model_name="vinai/phobert-base",
94
  num_emotion_classes=3,
95
  binary_cols=binary_cols,
96
+ lstm_hidden_size=128
97
  ).to(device)
98
 
99
  # Load model state dict
 
101
  model.eval()
102
 
103
  threshold_dict = {
104
+ 'sản phẩm': 0.28,
105
+ 'giá cả': 0.58,
106
+ 'vận chuyển': 0.58,
107
+ 'thái độ và dịch vụ khách hàng': 0.70,
108
+ 'khác': 0.6
109
  }
110
 
111
  def predict(text: str):