Update tasks/text.py

tasks/text.py (+104 −15)
@@ -6,12 +6,16 @@ from sklearn.linear_model import LogisticRegression
 from sklearn.feature_extraction.text import TfidfVectorizer
 from sklearn.model_selection import train_test_split
 import pandas as pd
+import re  # needed by the text-cleaning helpers below (assumed not imported earlier in the file)
+import tensorflow as tf
+from transformers import DistilBertTokenizer
+from transformers import TFDistilBertForSequenceClassification
+from tensorflow.keras.models import load_model
 from .utils.evaluation import TextEvaluationRequest
 from .utils.emissions import tracker, clean_emissions_data, get_space_info
 
 router = APIRouter()
 
-DESCRIPTION = "…
+DESCRIPTION = "DistilBert classification"
 ROUTE = "/text"
 
 @router.post(ROUTE, tags=["Text Task"],
@@ -20,8 +24,8 @@ async def evaluate_text(request: TextEvaluationRequest):
     """
     Evaluate text classification for climate disinformation detection.
 
-    Current Model:
-    - …
+    Current Model: DistilBert classification
+    - DistilBert classification predictions from the label space (0-7)
     - Used as a baseline for comparison
     """
     # Get space info
@@ -52,28 +56,91 @@ async def evaluate_text(request: TextEvaluationRequest):
     test_dataset = train_test["test"]
     train_dataset = train_test["train"]
     y_train=train_dataset['label']
-    …
+
+    # Build train/test frames from the quote, source and subsource fields
+    tn = pd.DataFrame([(i, j, k) for i, j, k in zip(train_dataset["quote"], train_dataset["source"],
+                                                    train_dataset["subsource"])], columns=['quote', 'source', 'subsource'])
+    tt = pd.DataFrame([(i, j, k) for i, j, k in zip(test_dataset["quote"], test_dataset["source"],
+                                                    test_dataset["subsource"])], columns=['quote', 'source', 'subsource'])
+    tt.fillna("", inplace=True)
+    tn.fillna("", inplace=True)
+    tn['text'] = tn[['quote', 'source', 'subsource']].agg(' '.join, axis=1)
+    tt['text'] = tt[['quote', 'source', 'subsource']].agg(' '.join, axis=1)
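+
+    # Cleanup helpers applied to the raw quotes before tokenization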
+    def clean_text(x):
+        # Drop everything except letters, digits and whitespace
+        pattern = r'[^a-zA-Z0-9\s]'
+        return re.sub(pattern, '', x)
+
+    def clean_numbers(x):
+        # Mask digit runs with '#' so specific numbers don't fragment the vocabulary
+        if bool(re.search(r'\d', x)):
+            x = re.sub('[0-9]{5,}', '#####', x)
+            x = re.sub('[0-9]{4}', '####', x)
+            x = re.sub('[0-9]{3}', '###', x)
+            x = re.sub('[0-9]{2}', '##', x)
+        return x
+
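+    # Expand English contractions (e.g. "won't" -> "will not") so the model sees full words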
+    contraction_dict = {"ain't": "is not", "aren't": "are not","can't": "cannot", "'cause": "because", "could've": "could have", "couldn't": "could not", "didn't": "did not", "doesn't": "does not", "don't": "do not", "hadn't": "had not", "hasn't": "has not", "haven't": "have not", "he'd": "he would","he'll": "he will", "he's": "he is", "how'd": "how did", "how'd'y": "how do you", "how'll": "how will", "how's": "how is", "I'd": "I would", "I'd've": "I would have", "I'll": "I will", "I'll've": "I will have","I'm": "I am", "I've": "I have", "i'd": "i would", "i'd've": "i would have", "i'll": "i will", "i'll've": "i will have","i'm": "i am", "i've": "i have", "isn't": "is not", "it'd": "it would", "it'd've": "it would have", "it'll": "it will", "it'll've": "it will have","it's": "it is", "let's": "let us", "ma'am": "madam", "mayn't": "may not", "might've": "might have","mightn't": "might not","mightn't've": "might not have", "must've": "must have", "mustn't": "must not", "mustn't've": "must not have", "needn't": "need not", "needn't've": "need not have","o'clock": "of the clock", "oughtn't": "ought not", "oughtn't've": "ought not have", "shan't": "shall not", "sha'n't": "shall not", "shan't've": "shall not have", "she'd": "she would", "she'd've": "she would have", "she'll": "she will", "she'll've": "she will have", "she's": "she is", "should've": "should have", "shouldn't": "should not", "shouldn't've": "should not have", "so've": "so have","so's": "so as", "this's": "this is","that'd": "that would", "that'd've": "that would have", "that's": "that is", "there'd": "there would", "there'd've": "there would have", "there's": "there is", "here's": "here is","they'd": "they would", "they'd've": "they would have", "they'll": "they will", "they'll've": "they will have", "they're": "they are", "they've": "they have", "to've": "to have", "wasn't": "was not", "we'd": "we would", "we'd've": "we would have", "we'll": "we will", "we'll've": "we will have", "we're": "we are", "we've": "we have", "weren't": "were not", "what'll": "what will", "what'll've": "what will have", "what're": "what are", "what's": "what is", "what've": "what have", "when's": "when is", "when've": "when have", "where'd": "where did", "where's": "where is", "where've": "where have", "who'll": "who will", "who'll've": "who will have", "who's": "who is", "who've": "who have", "why's": "why is", "why've": "why have", "will've": "will have", "won't": "will not", "won't've": "will not have", "would've": "would have", "wouldn't": "would not", "wouldn't've": "would not have", "y'all": "you all", "y'all'd": "you all would","y'all'd've": "you all would have","y'all're": "you all are","y'all've": "you all have","you'd": "you would", "you'd've": "you would have", "you'll": "you will", "you'll've": "you will have", "you're": "you are", "you've": "you have"}
+    def _get_contractions(contraction_dict):
+        contraction_re = re.compile('(%s)' % '|'.join(contraction_dict.keys()))
+        return contraction_dict, contraction_re
+    contractions, contractions_re = _get_contractions(contraction_dict)
+    def replace_contractions(text):
+        def replace(match):
+            return contractions[match.group(0)]
+        return contractions_re.sub(replace, text)
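+    # Normalization order: lower-case, strip punctuation, mask digits, expand contractions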
+    train_dataset_df = tn['quote'].apply(lambda x: x.lower())  # note: only 'quote' is used; the 'text' column built above is left unused
+    test_dataset_df = tt['quote'].apply(lambda x: x.lower())
+    # Clean the text
+    train_dataset_df = train_dataset_df.apply(lambda x: clean_text(x))
+    test_dataset_df = test_dataset_df.apply(lambda x: clean_text(x))
+    # Clean numbers
+    train_dataset_df = train_dataset_df.apply(lambda x: clean_numbers(x))
+    test_dataset_df = test_dataset_df.apply(lambda x: clean_numbers(x))
+    # Clean contractions
+    train_dataset_df = train_dataset_df.apply(lambda x: replace_contractions(x))
+    test_dataset_df = test_dataset_df.apply(lambda x: replace_contractions(x))
 
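+    # Integer-encode the labels and keep plain Python lists for tf.data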
+    y_train_df = pd.DataFrame(train_dataset['label'], columns=['label'])
+    y_test_df = pd.DataFrame(test_dataset['label'], columns=['label'])
+    y_train_encoded = y_train_df['label'].astype('category').cat.codes
+    y_test_encoded = y_test_df['label'].astype('category').cat.codes
+    train_labels = y_train_encoded.to_list()
+    test_labels = y_test_encoded.to_list()
+
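+    # Tokenize both splits; truncation/padding produce equal-length input_ids and attention_mask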
+    tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
+    train_encodings = tokenizer(train_dataset_df.to_list(), truncation=True, padding=True)
+    val_encodings = tokenizer(test_dataset_df.to_list(), truncation=True, padding=True)
+
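+    # Pair each encoded example (as a dict of features) with its integer label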
+    train_dataset_bert = tf.data.Dataset.from_tensor_slices((
+        dict(train_encodings),
+        train_labels
+    ))
+    val_dataset_bert = tf.data.Dataset.from_tensor_slices((
+        dict(val_encodings),
+        test_labels
+    ))
 
 
     # Start tracking emissions
     tracker.start()
     tracker.start_task("inference")
-    …
+    model = TFDistilBertForSequenceClassification.from_pretrained('distilbert-base-uncased', num_labels=8)
+
+    optimizer = tf.keras.optimizers.Adam(learning_rate=5e-5, epsilon=1e-08)
+    model.compile(optimizer=optimizer, loss=model.hf_compute_loss, metrics=['accuracy'])
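+    # hf_compute_loss applies the model's built-in sparse categorical cross-entropy over the 8 logits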
     #--------------------------------------------------------------------------------------------
     # YOUR MODEL INFERENCE CODE HERE
     # Update the code below to replace the random baseline by your model inference within the inference pass where the energy consumption and emissions are tracked.
     #--------------------------------------------------------------------------------------------
 
-    # Make …
-    …
-    predictions=LR.predict(pd.DataFrame.sparse.from_spmatrix(tfidf_test))
-    …
+    # Fine-tune DistilBERT (replaces the TFIDF + LogisticRegression baseline)
+
+    early_stopping = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=3, restore_best_weights=True)
+
+    model.fit(train_dataset_bert.shuffle(1000).batch(16),
+              epochs=4,  # batch_size argument removed: Keras rejects it when the input is an already-batched tf.data.Dataset
+              validation_data=val_dataset_bert.shuffle(1000).batch(16),
+              callbacks=[early_stopping])
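+    # Note: fine-tuning runs inside the tracked "inference" task, so training energy is counted here as well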
     #--------------------------------------------------------------------------------------------
     # YOUR MODEL INFERENCE STOPS HERE
     #--------------------------------------------------------------------------------------------
@@ -81,9 +148,31 @@ async def evaluate_text(request: TextEvaluationRequest):
 
     # Stop tracking emissions
    emissions_data = tracker.stop_task()
+
 
+    save_directory = "BERT"  # Change this to your preferred location
+
+    model.save_pretrained(save_directory)
+    tokenizer.save_pretrained(save_directory)
+    loaded_tokenizer = DistilBertTokenizer.from_pretrained(save_directory)
+    loaded_model = TFDistilBertForSequenceClassification.from_pretrained(save_directory)
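+    # Reloading straight after saving doubles as a sanity check that the artifacts round-trip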
+
     # Calculate accuracy
-    …
+    def predict_category(text):
+        predict_input = loaded_tokenizer.encode(text,
+                                                truncation=True,
+                                                padding=True,
+                                                return_tensors="tf")
+        output = loaded_model(predict_input)[0]
+        prediction_value = tf.argmax(output, axis=1).numpy()[0]
+        return prediction_value
+    #--------------------------------------------------------------------------------------------
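+    # Predicts one example at a time; a batched model.predict over val_dataset_bert would be faster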
+    y_pred = []
+    for text_ in test_dataset_df.to_list():
+        y_pred.append(predict_category(text_))
+
+    accuracy = accuracy_score(test_labels, y_pred)  # keep the score; a bare expression would discard it
 
     # Prepare results dictionary
     results = {