Terry Zhang commited on
Commit
a9f8367
·
1 Parent(s): 2b85173

add bert model code

Browse files
Files changed (1) hide show
  1. tasks/text.py +58 -5
tasks/text.py CHANGED
@@ -4,8 +4,11 @@ from datasets import load_dataset
4
  from sklearn.metrics import accuracy_score
5
  import random
6
  from skops.io import load
7
- # Textpreprocessor defined in this scope
8
-
 
 
 
9
 
10
  from .utils.evaluation import TextEvaluationRequest
11
  from .utils.emissions import tracker, clean_emissions_data, get_space_info
@@ -19,11 +22,10 @@ ROUTE = "/text"
19
  models_descriptions = {
20
  "baseline": "random baseline",
21
  "tfidf_xgb": "TF-IDF vectorizer and XGBoost classifier",
 
22
  }
23
 
24
 
25
- # Some code borrowed from Nonnormalizable
26
-
27
  def baseline_model(dataset_length: int):
28
  # Make random predictions (placeholder for actual model inference)
29
  predictions = [random.randint(0, 7) for _ in range(dataset_length)]
@@ -48,10 +50,59 @@ def tree_classifier(test_dataset: dict, model: str):
48
 
49
  return predictions
50
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
 
52
  @router.post(ROUTE, tags=["Text Task"])
53
  async def evaluate_text(request: TextEvaluationRequest,
54
- model: str = "tfidf_xgb"):
55
  """
56
  Evaluate text classification for climate disinformation detection.
57
 
@@ -100,6 +151,8 @@ async def evaluate_text(request: TextEvaluationRequest,
100
  predictions = baseline_model(len(true_labels))
101
  elif model == "tfidf_xgb":
102
  predictions = tree_classifier(test_dataset, model='xgb_pipeline')
 
 
103
 
104
  #--------------------------------------------------------------------------------------------
105
  # YOUR MODEL INFERENCE STOPS HERE
 
4
  from sklearn.metrics import accuracy_score
5
  import random
6
  from skops.io import load
7
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer, AutoConfig
8
+ import torch
9
+ from torch.utils.data import DataLoader, Dataset
10
+ import numpy as np
11
+ from accelerate.test_utils.testing import get_backend
12
 
13
  from .utils.evaluation import TextEvaluationRequest
14
  from .utils.emissions import tracker, clean_emissions_data, get_space_info
 
22
  models_descriptions = {
23
  "baseline": "random baseline",
24
  "tfidf_xgb": "TF-IDF vectorizer and XGBoost classifier",
25
+ "bert_base_pruned": "Pruned BERT base model",
26
  }
27
 
28
 
 
 
29
  def baseline_model(dataset_length: int):
30
  # Make random predictions (placeholder for actual model inference)
31
  predictions = [random.randint(0, 7) for _ in range(dataset_length)]
 
50
 
51
  return predictions
52
 
53
+ class TextDataset(Dataset):
54
+ def __init__(self, texts, tokenizer, max_length=256):
55
+ self.texts = texts
56
+ self.tokenized_texts = tokenizer(
57
+ texts,
58
+ truncation=True,
59
+ padding=True,
60
+ max_length=max_length,
61
+ return_tensors="pt",
62
+ )
63
+
64
+ def __getitem__(self, idx):
65
+ item = {key: val[idx] for key, val in self.tokenized_texts.items()}
66
+ return item
67
+
68
+ def __len__(self) -> int:
69
+ return len(self.texts)
70
+
71
+
72
+
73
+ def bert_classifier(test_dataset: dict, model: str):
74
+ texts = test_dataset["quote"]
75
+
76
+ model_repo = f"theterryzhang/frugal_ai_{model}"
77
+
78
+ model = AutoModelForSequenceClassification.from_pretrained(model_repo)
79
+ tokenizer = AutoTokenizer.from_pretrained(model_repo)
80
+
81
+ # Use CUDA if available
82
+ device, _, _ = get_backend()
83
+
84
+ model = model.to(device)
85
+
86
+ # Prepare dataset
87
+ dataset = TextDataset(texts, tokenizer=tokenizer)
88
+ dataloader = DataLoader(dataset, batch_size=32, shuffle=False)
89
+
90
+ model.eval()
91
+ with torch.no_grad():
92
+ predictions = np.array([])
93
+ for batch in dataloader:
94
+ test_input_ids = batch["input_ids"].to(device)
95
+ test_attention_mask = batch["attention_mask"].to(device)
96
+ outputs = model(test_input_ids, test_attention_mask)
97
+ p = torch.argmax(outputs.logits, dim=1)
98
+ predictions = np.append(predictions, p.cpu().numpy())
99
+
100
+ return predictions
101
+
102
 
103
  @router.post(ROUTE, tags=["Text Task"])
104
  async def evaluate_text(request: TextEvaluationRequest,
105
+ model: str = "bert_base_pruned"):
106
  """
107
  Evaluate text classification for climate disinformation detection.
108
 
 
151
  predictions = baseline_model(len(true_labels))
152
  elif model == "tfidf_xgb":
153
  predictions = tree_classifier(test_dataset, model='xgb_pipeline')
154
+ elif 'bert' in model:
155
+ predictions = bert_classifier(test_dataset, model)
156
 
157
  #--------------------------------------------------------------------------------------------
158
  # YOUR MODEL INFERENCE STOPS HERE