import os

# Set TensorFlow's C++ log level before TensorFlow is imported so it takes effect.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1'

import re
from datetime import datetime

import pandas as pd
import tensorflow as tf
from datasets import load_dataset
from fastapi import APIRouter
from sklearn.metrics import accuracy_score
from transformers import (
    DistilBertTokenizer,
    TFDistilBertForSequenceClassification,
    logging,
)

from .utils.emissions import clean_emissions_data, get_space_info, tracker
from .utils.evaluation import TextEvaluationRequest

# Keep only error-level messages from transformers.
logging.set_verbosity_error()

router = APIRouter()

DESCRIPTION = "DistilBert classification"
ROUTE = "/text"


@router.post(ROUTE, tags=["Text Task"], description=DESCRIPTION)
async def evaluate_text(request: TextEvaluationRequest):
    """
    Evaluate text classification for climate disinformation detection.

    Current model: DistilBERT fine-tuned for sequence classification.
    - Predicts one label from the 8-class label space (0-7).
    - Used as a baseline for comparison.
    """
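    # High-level flow: load and relabel the dataset, split it, clean the quote
    # text, fine-tune DistilBERT, then score the held-out split and report
    # accuracy alongside the tracked emissions data.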

    username, space_url = get_space_info()

    # Map the dataset's string labels to integer class ids.
    LABEL_MAPPING = {
        "0_not_relevant": 0,
        "1_not_happening": 1,
        "2_not_human": 2,
        "3_not_bad": 3,
        "4_solutions_harmful_unnecessary": 4,
        "5_science_unreliable": 5,
        "6_proponents_biased": 6,
        "7_fossil_fuels_needed": 7,
    }

    # Load the requested dataset and convert its string labels to integers.
    dataset = load_dataset(request.dataset_name)
    dataset = dataset.map(lambda x: {"label": LABEL_MAPPING[x["label"]]})

    # Split the train portion into train and test sets with the requested seed.
    train_test = dataset["train"].train_test_split(
        test_size=request.test_size, seed=request.test_seed
    )
    train_dataset = train_test["train"]
    test_dataset = train_test["test"]

    # Collect the text fields of each split into DataFrames for preprocessing.
    tn = pd.DataFrame({
        "quote": train_dataset["quote"],
        "source": train_dataset["source"],
        "subsource": train_dataset["subsource"],
    })
    tt = pd.DataFrame({
        "quote": test_dataset["quote"],
        "source": test_dataset["source"],
        "subsource": test_dataset["subsource"],
    })
    tn.fillna("", inplace=True)
    tt.fillna("", inplace=True)

    # Concatenate quote, source and subsource into a single text column.
    tn['text'] = tn[['quote', 'source', 'subsource']].agg(' '.join, axis=1)
    tt['text'] = tt[['quote', 'source', 'subsource']].agg(' '.join, axis=1)
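
    # Text-cleaning helpers applied to the quotes before tokenization.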
    def clean_text(x):
        # Strip every character that is not alphanumeric or whitespace.
        pattern = r'[^a-zA-Z0-9\s]'
        return re.sub(pattern, '', x)

    def clean_numbers(x):
        # Mask digit runs with '#' placeholders of matching length buckets.
        if bool(re.search(r'\d', x)):
            x = re.sub('[0-9]{5,}', '#####', x)
            x = re.sub('[0-9]{4}', '####', x)
            x = re.sub('[0-9]{3}', '###', x)
            x = re.sub('[0-9]{2}', '##', x)
        return x
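
    # Mapping of English contractions to their expanded forms, consumed by
    # replace_contractions below.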
contraction_dict = {"ain't": "is not", "aren't": "are not","can't": "cannot", "'cause": "because", "could've": "could have", "couldn't": "could not", "didn't": "did not", "doesn't": "does not", "don't": "do not", "hadn't": "had not", "hasn't": "has not", "haven't": "have not", "he'd": "he would","he'll": "he will", "he's": "he is", "how'd": "how did", "how'd'y": "how do you", "how'll": "how will", "how's": "how is", "I'd": "I would", "I'd've": "I would have", "I'll": "I will", "I'll've": "I will have","I'm": "I am", "I've": "I have", "i'd": "i would", "i'd've": "i would have", "i'll": "i will", "i'll've": "i will have","i'm": "i am", "i've": "i have", "isn't": "is not", "it'd": "it would", "it'd've": "it would have", "it'll": "it will", "it'll've": "it will have","it's": "it is", "let's": "let us", "ma'am": "madam", "mayn't": "may not", "might've": "might have","mightn't": "might not","mightn't've": "might not have", "must've": "must have", "mustn't": "must not", "mustn't've": "must not have", "needn't": "need not", "needn't've": "need not have","o'clock": "of the clock", "oughtn't": "ought not", "oughtn't've": "ought not have", "shan't": "shall not", "sha'n't": "shall not", "shan't've": "shall not have", "she'd": "she would", "she'd've": "she would have", "she'll": "she will", "she'll've": "she will have", "she's": "she is", "should've": "should have", "shouldn't": "should not", "shouldn't've": "should not have", "so've": "so have","so's": "so as", "this's": "this is","that'd": "that would", "that'd've": "that would have", "that's": "that is", "there'd": "there would", "there'd've": "there would have", "there's": "there is", "here's": "here is","they'd": "they would", "they'd've": "they would have", "they'll": "they will", "they'll've": "they will have", "they're": "they are", "they've": "they have", "to've": "to have", "wasn't": "was not", "we'd": "we would", "we'd've": "we would have", "we'll": "we will", "we'll've": "we will have", "we're": "we are", "we've": "we have", "weren't": "were not", "what'll": "what will", "what'll've": "what will have", "what're": "what are", "what's": "what is", "what've": "what have", "when's": "when is", "when've": "when have", "where'd": "where did", "where's": "where is", "where've": "where have", "who'll": "who will", "who'll've": "who will have", "who's": "who is", "who've": "who have", "why's": "why is", "why've": "why have", "will've": "will have", "won't": "will not", "won't've": "will not have", "would've": "would have", "wouldn't": "would not", "wouldn't've": "would not have", "y'all": "you all", "y'all'd": "you all would","y'all'd've": "you all would have","y'all're": "you all are","y'all've": "you all have","you'd": "you would", "you'd've": "you would have", "you'll": "you will", "you'll've": "you will have", "you're": "you are", "you've": "you have"} |
|
|
|

    def _get_contractions(contraction_dict):
        contraction_re = re.compile('(%s)' % '|'.join(contraction_dict.keys()))
        return contraction_dict, contraction_re

    contractions, contractions_re = _get_contractions(contraction_dict)

    def replace_contractions(text):
        # Expand contractions ("can't" -> "cannot") using the mapping above.
        def replace(match):
            return contractions[match.group(0)]

        return contractions_re.sub(replace, text)
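
    # Apply the cleaning steps to the lowercased quotes of both splits.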
    train_dataset_df = tn['quote'].apply(lambda x: x.lower())
    test_dataset_df = tt['quote'].apply(lambda x: x.lower())

    train_dataset_df = train_dataset_df.apply(clean_text)
    test_dataset_df = test_dataset_df.apply(clean_text)

    train_dataset_df = train_dataset_df.apply(clean_numbers)
    test_dataset_df = test_dataset_df.apply(clean_numbers)

    train_dataset_df = train_dataset_df.apply(replace_contractions)
    test_dataset_df = test_dataset_df.apply(replace_contractions)

    # Encode the labels as contiguous integer codes and keep plain Python lists.
    y_train_df = pd.DataFrame(train_dataset['label'], columns=['label'])
    y_test_df = pd.DataFrame(test_dataset['label'], columns=['label'])
    y_train_encoded = y_train_df['label'].astype('category').cat.codes
    y_test_encoded = y_test_df['label'].astype('category').cat.codes
    train_labels = y_train_encoded.to_list()
    test_labels = y_test_encoded.to_list()

    # Tokenize both splits with the DistilBERT tokenizer (padded and truncated).
    tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
    train_encodings = tokenizer(train_dataset_df.to_list(), truncation=True, padding=True)
    val_encodings = tokenizer(test_dataset_df.to_list(), truncation=True, padding=True)
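
    # Wrap the tokenized encodings and integer labels into tf.data datasets
    # for Keras training and validation.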
    train_dataset_bert = tf.data.Dataset.from_tensor_slices((
        dict(train_encodings),
        train_labels,
    ))
    val_dataset_bert = tf.data.Dataset.from_tensor_slices((
        dict(val_encodings),
        test_labels,
    ))

    model = TFDistilBertForSequenceClassification.from_pretrained(
        'distilbert-base-uncased', num_labels=8
    )
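
    # Start emissions tracking. The "inference" task opened here is only
    # stopped after model.fit() returns, so it also covers fine-tuning.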
    tracker.start()
    tracker.start_task("inference")

    # Fine-tune DistilBERT on the training split.
    optimizer = tf.keras.optimizers.Adam(learning_rate=5e-5, epsilon=1e-08)
    model.compile(optimizer=optimizer, loss=model.hf_compute_loss, metrics=['accuracy'])

    early_stopping = tf.keras.callbacks.EarlyStopping(
        monitor='val_loss', patience=3, restore_best_weights=True
    )
    model.fit(
        train_dataset_bert.shuffle(1000).batch(16),
        epochs=2,
        validation_data=val_dataset_bert.shuffle(1000).batch(16),
        callbacks=[early_stopping],
    )

    emissions_data = tracker.stop_task()
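
    # Classify a single cleaned quote and return the predicted class id.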
    def predict_category(text):
        predict_input = tokenizer.encode(
            text, truncation=True, padding=True, return_tensors="tf"
        )
        output = model(predict_input)[0]
        prediction_value = tf.argmax(output, axis=1).numpy()[0]
        return prediction_value
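
    # Run the fine-tuned model over every quote in the test split.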
    y_pred = []
    for text_ in test_dataset_df.to_list():
        y_pred.append(predict_category(text_))

    accuracy = accuracy_score(test_labels, y_pred)

    results = {
        "username": username,
        "space_url": space_url,
        "submission_timestamp": datetime.now().isoformat(),
        "model_description": DESCRIPTION,
        "accuracy": float(accuracy),
        "energy_consumed_wh": emissions_data.energy_consumed * 1000,
        "emissions_gco2eq": emissions_data.emissions * 1000,
        "emissions_data": clean_emissions_data(emissions_data),
        "api_route": ROUTE,
        "dataset_config": {
            "dataset_name": request.dataset_name,
            "test_size": request.test_size,
            "test_seed": request.test_seed,
        },
    }

    return results