Tonic commited on
Commit
68ff849
·
unverified ·
1 Parent(s): f3f30d7

add pipeline

Browse files
Files changed (1) hide show
  1. tasks/text.py +12 -24
tasks/text.py CHANGED
@@ -7,7 +7,7 @@ import os
7
  from concurrent.futures import ThreadPoolExecutor
8
  from typing import List, Dict, Tuple
9
  import torch
10
- from transformers import AutoModelForSequenceClassification, AutoTokenizer
11
 
12
  from .utils.evaluation import TextEvaluationRequest
13
  from .utils.emissions import tracker, clean_emissions_data, get_space_info, start_tracking, stop_tracking
@@ -26,14 +26,12 @@ class TextClassifier:
26
  max_retries = 3
27
  for attempt in range(max_retries):
28
  try:
29
- # Load model and tokenizer directly instead of using pipeline
30
- self.model = AutoModelForSequenceClassification.from_pretrained(
31
- "Tonic/climate-guard-toxic-agent"
32
- ).to(self.device)
33
- self.tokenizer = AutoTokenizer.from_pretrained(
34
- "Tonic/climate-guard-toxic-agent"
35
  )
36
- self.model.eval() # Set to evaluation mode
37
  print("Model initialized successfully")
38
  break
39
  except Exception as e:
@@ -45,18 +43,11 @@ class TextClassifier:
45
  def predict_single(self, text: str) -> int:
46
  """Predict single text instance"""
47
  try:
48
- inputs = self.tokenizer(
49
- text,
50
- return_tensors="pt",
51
- truncation=True,
52
- max_length=512,
53
- padding=True
54
- ).to(self.device)
55
-
56
- with torch.no_grad():
57
- outputs = self.model(**inputs)
58
- predictions = outputs.logits.argmax(-1)
59
- return predictions.item()
60
  except Exception as e:
61
  print(f"Error in single prediction: {str(e)}")
62
  return 0 # Return default prediction on error
@@ -114,15 +105,13 @@ async def evaluate_text(request: TextEvaluationRequest):
114
  # Start tracking emissions
115
  start_tracking()
116
 
117
- # tracker.start_task("inference")
118
-
119
  true_labels = test_dataset["label"]
120
 
121
  # Initialize the model once
122
  classifier = TextClassifier()
123
 
124
  # Prepare batches
125
- batch_size = 16 # Reduced batch size for better memory management
126
  quotes = test_dataset["quote"]
127
  num_batches = len(quotes) // batch_size + (1 if len(quotes) % batch_size != 0 else 0)
128
  batches = [
@@ -162,7 +151,6 @@ async def evaluate_text(request: TextEvaluationRequest):
162
 
163
  # Stop tracking emissions
164
  emissions_data = stop_tracking()
165
- # emissions_data = tracker.stop_task()
166
 
167
  # Calculate accuracy
168
  accuracy = accuracy_score(true_labels, predictions)
 
7
  from concurrent.futures import ThreadPoolExecutor
8
  from typing import List, Dict, Tuple
9
  import torch
10
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline
11
 
12
  from .utils.evaluation import TextEvaluationRequest
13
  from .utils.emissions import tracker, clean_emissions_data, get_space_info, start_tracking, stop_tracking
 
26
  max_retries = 3
27
  for attempt in range(max_retries):
28
  try:
29
+ # Initialize using pipeline instead
30
+ self.classifier = pipeline(
31
+ "text-classification",
32
+ model="Tonic/climate-guard-toxic-agent",
33
+ device=self.device
 
34
  )
 
35
  print("Model initialized successfully")
36
  break
37
  except Exception as e:
 
43
  def predict_single(self, text: str) -> int:
44
  """Predict single text instance"""
45
  try:
46
+ result = self.classifier(text)
47
+ # Extract the label index from the result
48
+ # Assuming the model outputs label indices directly
49
+ label = int(result[0]['label'].split('_')[0])
50
+ return label
 
 
 
 
 
 
 
51
  except Exception as e:
52
  print(f"Error in single prediction: {str(e)}")
53
  return 0 # Return default prediction on error
 
105
  # Start tracking emissions
106
  start_tracking()
107
 
 
 
108
  true_labels = test_dataset["label"]
109
 
110
  # Initialize the model once
111
  classifier = TextClassifier()
112
 
113
  # Prepare batches
114
+ batch_size = 16
115
  quotes = test_dataset["quote"]
116
  num_batches = len(quotes) // batch_size + (1 if len(quotes) % batch_size != 0 else 0)
117
  batches = [
 
151
 
152
  # Stop tracking emissions
153
  emissions_data = stop_tracking()
 
154
 
155
  # Calculate accuracy
156
  accuracy = accuracy_score(true_labels, predictions)