add pipeline
tasks/text.py  +12 -24  CHANGED
@@ -7,7 +7,7 @@ import os
 from concurrent.futures import ThreadPoolExecutor
 from typing import List, Dict, Tuple
 import torch
-from transformers import AutoModelForSequenceClassification, AutoTokenizer
+from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline
 
 from .utils.evaluation import TextEvaluationRequest
 from .utils.emissions import tracker, clean_emissions_data, get_space_info, start_tracking, stop_tracking
@@ -26,14 +26,12 @@ class TextClassifier:
         max_retries = 3
         for attempt in range(max_retries):
             try:
-                # Initialize model and tokenizer
-                self.tokenizer = AutoTokenizer.from_pretrained(
-                    "Tonic/climate-guard-toxic-agent"
-                )
-                self.model = AutoModelForSequenceClassification.from_pretrained(
-                    "Tonic/climate-guard-toxic-agent"
+                # Initialize using pipeline instead
+                self.classifier = pipeline(
+                    "text-classification",
+                    model="Tonic/climate-guard-toxic-agent",
+                    device=self.device
                 )
-                self.model.eval()  # Set to evaluation mode
                 print("Model initialized successfully")
                 break
             except Exception as e:
@@ -45,18 +43,11 @@ class TextClassifier:
     def predict_single(self, text: str) -> int:
         """Predict single text instance"""
         try:
-            inputs = self.tokenizer(
-                text,
-                return_tensors="pt",
-                truncation=True,
-                max_length=512,
-                padding=True
-            ).to(self.device)
-
-            with torch.no_grad():
-                outputs = self.model(**inputs)
-                predictions = outputs.logits.argmax(-1)
-                return predictions.item()
+            result = self.classifier(text)
+            # Extract the label index from the result
+            # Assuming the model outputs label indices directly
+            label = int(result[0]['label'].split('_')[0])
+            return label
         except Exception as e:
             print(f"Error in single prediction: {str(e)}")
             return 0  # Return default prediction on error
@@ -114,15 +105,13 @@ async def evaluate_text(request: TextEvaluationRequest):
     # Start tracking emissions
     start_tracking()
 
-    # tracker.start_task("inference")
-
     true_labels = test_dataset["label"]
 
     # Initialize the model once
     classifier = TextClassifier()
 
     # Prepare batches
-    batch_size = 16
+    batch_size = 16
     quotes = test_dataset["quote"]
     num_batches = len(quotes) // batch_size + (1 if len(quotes) % batch_size != 0 else 0)
     batches = [
@@ -162,7 +151,6 @@ async def evaluate_text(request: TextEvaluationRequest):
 
     # Stop tracking emissions
     emissions_data = stop_tracking()
-    # emissions_data = tracker.stop_task()
 
     # Calculate accuracy
     accuracy = accuracy_score(true_labels, predictions)
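
A note on the new code path: a text-classification pipeline returns a list of dicts of the form [{'label': ..., 'score': ...}], so the int(result[0]['label'].split('_')[0]) parsing in predict_single only works if the checkpoint's label names begin with the class index (e.g. "0_not_relevant"). A minimal sketch of the new path under that assumption; the input text and CPU device are illustrative:

from transformers import pipeline

# Same checkpoint the Space loads; device=-1 means CPU, a CUDA index
# such as 0 selects a GPU (self.device in the diff is assumed to
# resolve to one of these values).
classifier = pipeline(
    "text-classification",
    model="Tonic/climate-guard-toxic-agent",
    device=-1,
)

result = classifier("Climate change is just a natural cycle.")
# e.g. [{'label': '1_not_happening', 'score': 0.93}] -> class index 1
label = int(result[0]["label"].split("_")[0])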
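
Side note on the batching arithmetic kept in evaluate_text: the num_batches expression is a hand-rolled ceiling division. A quick sanity check with a made-up dataset size:

import math

quotes = ["sample quote"] * 37  # hypothetical number of test quotes
batch_size = 16

num_batches = len(quotes) // batch_size + (1 if len(quotes) % batch_size != 0 else 0)
assert num_batches == math.ceil(len(quotes) / batch_size) == 3

# Slicing past the end of a list is safe, so the last batch is just shorter.
batches = [quotes[i * batch_size:(i + 1) * batch_size] for i in range(num_batches)]
assert sum(len(b) for b in batches) == len(quotes)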