Terry Zhang committed on
Commit
df46342
·
1 Parent(s): c422e81

proper sbert model load

Browse files
Files changed (1) hide show
  1. tasks/text.py +37 -8
tasks/text.py CHANGED
@@ -1,18 +1,21 @@
1
- from fastapi import APIRouter
2
  from datetime import datetime
3
- from datasets import load_dataset
4
- from sklearn.metrics import accuracy_score
5
  import random
6
- from skops.io import load
7
- from transformers import AutoModelForSequenceClassification, AutoTokenizer, AutoConfig
8
  import torch
 
9
  from torch.utils.data import DataLoader, Dataset
10
- import numpy as np
11
- from accelerate.test_utils.testing import get_backend
 
 
 
 
12
 
13
  from .utils.evaluation import TextEvaluationRequest
14
  from .utils.emissions import tracker, clean_emissions_data, get_space_info
15
  from .utils.text_preprocessor import preprocess
 
16
 
17
  router = APIRouter()
18
 
@@ -27,6 +30,26 @@ models_descriptions = {
27
  "sbert_distilroberta": "Fine-tuned sentence transformer DistilRoBERTa"
28
  }
29
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
 
31
  def baseline_model(dataset_length: int):
32
  # Make random predictions (placeholder for actual model inference)
@@ -81,9 +104,15 @@ def bert_classifier(test_dataset: dict, model: str):
81
 
82
  model_repo = f"theterryzhang/frugal_ai_{model}"
83
 
84
- model = AutoModelForSequenceClassification.from_pretrained(model_repo)
85
  tokenizer = AutoTokenizer.from_pretrained(model_repo)
86
 
 
 
 
 
 
 
 
87
  # Use CUDA if available
88
  device, _, _ = get_backend()
89
 
 
 
1
  from datetime import datetime
 
 
2
  import random
3
+
4
+ import numpy as np
5
  import torch
6
+ from torch import nn
7
  from torch.utils.data import DataLoader, Dataset
8
+ from transformers import AutoModel, AutoModelForSequenceClassification, AutoTokenizer
9
+ from fastapi import APIRouter
10
+ from datasets import load_dataset
11
+ from sklearn.metrics import accuracy_score
12
+ from skops.io import load
13
+ from huggingface_hub import PyTorchModelHubMixin
14
 
15
  from .utils.evaluation import TextEvaluationRequest
16
  from .utils.emissions import tracker, clean_emissions_data, get_space_info
17
  from .utils.text_preprocessor import preprocess
18
+ from accelerate.test_utils.testing import get_backend
19
 
20
  router = APIRouter()
21
 
 
30
  "sbert_distilroberta": "Fine-tuned sentence transformer DistilRoBERTa"
31
  }
32
 
33
class SentenceBERTMultiClass(nn.Module, PyTorchModelHubMixin):
    """Encoder loaded with ``AutoModel`` plus a dropout + linear classification head.

    Args:
        model_name: Identifier or path passed to ``AutoModel.from_pretrained``.
        num_labels: Output size of the linear classification layer (default 8).
    """

    def __init__(self, model_name, num_labels=8):
        super().__init__()
        self.sbert = AutoModel.from_pretrained(model_name)
        # Expose the encoder's config on the wrapper itself.
        self.config = self.sbert.config
        self.dropout = nn.Dropout(0.05)
        self.classifier = nn.Linear(self.sbert.config.hidden_size, num_labels)

    def forward(self, input_ids, attention_mask):
        """Return classification logits for a tokenized batch.

        Args:
            input_ids: Token ids forwarded to the encoder.
            attention_mask: Attention mask forwarded to the encoder.

        Returns:
            Output of the linear classification layer applied to the pooled
            encoder representation (presumably one row of ``num_labels``
            logits per input sequence -- confirm against the caller).
        """
        outputs = self.sbert(input_ids=input_ids, attention_mask=attention_mask)

        # Some encoders expose a `pooler_output` field that is None (no pooling
        # layer); `hasattr` alone would accept it and then crash in dropout.
        # Treat a missing and a None pooled output the same way.
        pooled_output = getattr(outputs, "pooler_output", None)
        if pooled_output is None:
            # Fallback: plain mean over dim=1 of the last hidden state
            # (unchanged from the original; the attention mask is not applied).
            pooled_output = outputs.last_hidden_state.mean(dim=1)

        dropout_output = self.dropout(pooled_output)
        logits = self.classifier(dropout_output)

        return logits
52
+
53
 
54
  def baseline_model(dataset_length: int):
55
  # Make random predictions (placeholder for actual model inference)
 
104
 
105
  model_repo = f"theterryzhang/frugal_ai_{model}"
106
 
 
107
  tokenizer = AutoTokenizer.from_pretrained(model_repo)
108
 
109
+ if model.isin(['bert_base_pruned']):
110
+ model = AutoModelForSequenceClassification.from_pretrained(model_repo)
111
+ elif model.isin(['sbert_distilroberta']):
112
+ model = SentenceBERTMultiClass.from_pretrained(model_repo)
113
+ else:
114
+ raise(ValueError)
115
+
116
  # Use CUDA if available
117
  device, _, _ = get_backend()
118