|
import datasets |
|
import logging |
|
import json |
|
import pandas as pd |
|
|
|
|
|
def text_classificaiton_match_label_case_unsensative(id2label_mapping, label):
    """Find the model label that equals ``label`` when letter case is ignored.

    Args:
        id2label_mapping: Mapping whose keys are the model's label names.
        label: Dataset label to look up.

    Returns:
        A ``(model_label, label)`` tuple. ``model_label`` is the first key of
        ``id2label_mapping`` whose upper-cased form equals the upper-cased
        ``label``, or ``None`` when no key matches. ``label`` is passed
        through unchanged.
    """
    case_insensitive_hits = (
        candidate
        for candidate in id2label_mapping
        if candidate.upper() == label.upper()
    )
    # Only the first hit (in key order) matters; fall back to None.
    return next(case_insensitive_hits, None), label
|
|
|
|
|
def text_classification_map_model_and_dataset_labels(id2label, dataset_features):
    """Pair the model's labels with the labels of a dataset ``ClassLabel`` feature.

    Only features that are ``datasets.ClassLabel`` instances with exactly as
    many names as the model has distinct labels are considered. A dataset
    label is paired with the identically named model label, or failing that
    with a model label that matches when letter case is ignored.

    Args:
        id2label: Mapping from model class id to model label name.
        dataset_features: Mapping from column name to dataset feature.

    Returns:
        A ``(id2label_mapping, dataset_labels)`` tuple. ``id2label_mapping``
        maps every model label to the dataset label paired with it (``None``
        if unpaired). ``dataset_labels`` is the ``names`` list of the last
        feature that was considered, or ``None`` if there was none.
    """
    # Every model label starts out without a dataset counterpart.
    id2label_mapping = dict.fromkeys(id2label.values())
    # Keys are never added or removed below, so the count can be taken once.
    model_label_count = len(id2label_mapping)
    dataset_labels = None

    for feature in dataset_features.values():
        is_label_feature = (
            isinstance(feature, datasets.ClassLabel)
            and len(feature.names) == model_label_count
        )
        if not is_label_feature:
            continue

        dataset_labels = feature.names

        for dataset_label in dataset_labels:
            if dataset_label in id2label_mapping:
                matched_model_label = dataset_label
            else:
                # No exact match: retry ignoring letter case.
                matched_model_label, dataset_label = (
                    text_classificaiton_match_label_case_unsensative(
                        id2label_mapping, dataset_label
                    )
                )

            if matched_model_label is None:
                print(f"Label {dataset_label} is not found in model labels")
            else:
                id2label_mapping[matched_model_label] = dataset_label

    return id2label_mapping, dataset_labels
|
|
|
def check_column_mapping_keys_validity(column_mapping, ppl):
    """Check a user-edited label mapping against the model's own labels.

    Args:
        column_mapping: JSON string encoding an object. When the object has a
            ``"data"`` key, its value is read as a sequence of
            ``[user_label, model_label]`` pairs, for example
            ``'{"data": [["LABEL_0", "LABEL_1"], ["LABEL_1", "LABEL_0"]]}'``.
        ppl: Pipeline-like object exposing ``model.config.id2label``.

    Returns:
        ``True`` when the decoded object has no ``"data"`` key. Otherwise
        ``True`` only if the set of user labels, the set of model labels
        taken from the pairs, and the set of ``id2label`` values are all
        equal.
    """
    column_mapping = json.loads(column_mapping)
    if "data" not in column_mapping.keys():
        # Nothing was edited by the user, so there is nothing to validate.
        return True

    label_pairs = column_mapping["data"]
    user_labels = {pair[0] for pair in label_pairs}
    model_labels = {pair[1] for pair in label_pairs}
    original_labels = set(ppl.model.config.id2label.values())

    # NOTE(review): the chained comparison also requires the user labels to
    # equal the model labels, so a mapping from differently named dataset
    # labels is reported as invalid -- confirm this is intended.
    return user_labels == model_labels == original_labels
|
|
|
def infer_text_input_column(column_mapping, dataset_features):
    """Make sure ``column_mapping["text"]`` names a usable dataset column.

    If ``column_mapping`` already has a ``"text"`` entry that is one of the
    dataset's columns, nothing is changed. Otherwise the first feature whose
    ``dtype`` is ``"string"`` is written into ``column_mapping["text"]``.

    Args:
        column_mapping: Dict of column mappings; updated in place.
        dataset_features: Mapping from column name to dataset feature.

    Returns:
        A ``(column_mapping, feature_map_df)`` tuple. ``feature_map_df`` is a
        one-row DataFrame describing the inferred column, or ``None`` when no
        column was inferred (either none was needed or none was found).
    """
    needs_inference = True
    feature_map_df = None

    if "text" in column_mapping:
        dataset_text_column = column_mapping["text"]
        if dataset_text_column in dataset_features:
            needs_inference = False
        else:
            logging.warning(f"Provided {dataset_text_column} is not in Dataset columns")

    if not needs_inference:
        return column_mapping, feature_map_df

    # Features without a ``dtype`` attribute are skipped instead of raising.
    candidates = [
        name
        for name, feature in dataset_features.items()
        if getattr(feature, "dtype", None) == "string"
    ]

    # Only touch candidates[0] once we know a candidate exists; the dataset
    # may have no string column at all.
    if candidates:
        logging.debug(f"Candidates are {candidates}")
        column_mapping["text"] = candidates[0]
        feature_map_df = pd.DataFrame({
            "Dataset Features": [candidates[0]],
            "Model Input Features": ["text"],
        })

    return column_mapping, feature_map_df
|
|
|
def text_classification_fix_column_mapping(column_mapping, ppl, d_id, config, split):
    """Complete a column mapping for a text-classification model and dataset.

    Loads the dataset split, infers the text column and the label pairing,
    and runs the pipeline on the first row as a sample prediction.

    Args:
        column_mapping: Dict of column mappings; updated in place. A
            ``"data"`` entry holding a list of ``[user_label, model_label]``
            pairs overrides the automatically detected label pairing.
        ppl: Callable pipeline exposing ``model.config.id2label``; it is
            invoked as ``ppl({"text": ...}, top_k=None)``.
        d_id: Dataset identifier passed to ``datasets.load_dataset``.
        config: Dataset configuration passed to ``datasets.load_dataset``.
        split: Name of the split to read from the loaded dataset.

    Returns:
        A 5-tuple ``(column_mapping, prediction_input, prediction_result,
        id2label_df, feature_map_df)``:

        * all ``None`` when the loaded split has no ``features`` attribute;
        * ``(column_mapping, None, None, None, feature_map_df)`` when the
          label pairing could not be established (no ``"data"`` entry and at
          least one unpaired model label, or no dataset label feature);
        * ``prediction_result`` is ``None`` when the sample prediction raised.
    """
    ds = datasets.load_dataset(d_id, config)[split]
    try:
        dataset_features = ds.features
    except AttributeError:
        # The loaded object exposes no features; nothing can be inferred.
        return None, None, None, None, None

    column_mapping, feature_map_df = infer_text_input_column(column_mapping, dataset_features)

    df = ds.to_pandas()

    id2label = ppl.model.config.id2label
    label2id = {v: k for k, v in id2label.items()}

    id2label_mapping, dataset_labels = text_classification_map_model_and_dataset_labels(id2label, dataset_features)

    if "data" in column_mapping.keys():
        if isinstance(column_mapping["data"], list):
            # The pairing supplied by the user wins over the detected one.
            for user_label, model_label in column_mapping["data"]:
                id2label_mapping[model_label] = user_label
    elif None in id2label_mapping.values():
        column_mapping["label"] = {
            i: None for i in id2label.keys()
        }
        return column_mapping, None, None, None, feature_map_df

    if dataset_labels is None:
        # No dataset label feature was found, so there is no label table to
        # build; report it like the other "could not infer" case instead of
        # failing on iteration over None.
        return column_mapping, None, None, None, feature_map_df

    # Invert only AFTER the user overrides are applied; inverting earlier
    # loses user-supplied pairs and makes the lookups below fail for them.
    id2label_mapping_dataset_model = {
        v: k for k, v in id2label_mapping.items()
    }

    id2label_df = pd.DataFrame({
        "Dataset Labels": dataset_labels,
        # .get: a dataset label left unpaired shows up as None in the table.
        "Model Prediction Labels": [id2label_mapping_dataset_model.get(label) for label in dataset_labels],
    })

    prediction_input = None
    try:
        # The first row of the dataset serves as the sample input.
        prediction_input = df.head(1).at[0, column_mapping["text"]]
        results = ppl({"text": prediction_input}, top_k=None)
        prediction_result = {
            f'[{label2id[result["label"]]}]{result["label"]}(original) - {id2label_mapping[result["label"]]}(mapped)': result["score"]
            for result in results
        }
    except Exception as e:
        # Best effort: a failed sample prediction still returns the mapping.
        print(e, '>>>> error')
        return column_mapping, prediction_input, None, id2label_df, feature_map_df

    if "data" not in column_mapping.keys():
        column_mapping["label"] = {
            str(i): id2label_mapping_dataset_model[label] for i, label in zip(id2label.keys(), dataset_labels)
        }

    return column_mapping, prediction_input, prediction_result, id2label_df, feature_map_df
|
|