hardiktiwari's picture
Upload 244 files
33d4721 verified
"""
Copyright 2023 The HuggingFace Team
"""
import os
from dataclasses import dataclass
from typing import Union
from autotrain.backends.base import AVAILABLE_HARDWARE
from autotrain.backends.endpoints import EndpointsRunner
from autotrain.backends.local import LocalRunner
from autotrain.backends.ngc import NGCRunner
from autotrain.backends.nvcf import NVCFRunner
from autotrain.backends.spaces import SpaceRunner
from autotrain.dataset import (
AutoTrainDataset,
AutoTrainImageClassificationDataset,
AutoTrainImageRegressionDataset,
AutoTrainObjectDetectionDataset,
AutoTrainVLMDataset,
)
from autotrain.trainers.clm.params import LLMTrainingParams
from autotrain.trainers.extractive_question_answering.params import ExtractiveQuestionAnsweringParams
from autotrain.trainers.image_classification.params import ImageClassificationParams
from autotrain.trainers.image_regression.params import ImageRegressionParams
from autotrain.trainers.object_detection.params import ObjectDetectionParams
from autotrain.trainers.sent_transformers.params import SentenceTransformersParams
from autotrain.trainers.seq2seq.params import Seq2SeqParams
from autotrain.trainers.tabular.params import TabularParams
from autotrain.trainers.text_classification.params import TextClassificationParams
from autotrain.trainers.text_regression.params import TextRegressionParams
from autotrain.trainers.token_classification.params import TokenClassificationParams
from autotrain.trainers.vlm.params import VLMTrainingParams
def tabular_munge_data(params, local):
    """
    Convert a local tabular CSV/JSONL dataset into AutoTrain's internal format.

    The user-facing ``params.task`` ("classification" or "regression") is combined
    with the number of target columns to pick the concrete AutoTrain task name.
    The training file is looked up as ``{data_path}/{train_split}.csv`` and then
    ``.jsonl``. When it exists, it is converted with ``AutoTrainDataset`` and the
    column fields on ``params`` are repointed at the AutoTrain-generated names.

    Parameters
    ----------
    params : TabularParams
        Reads target_columns, task, data_path, train_split, valid_split,
        id_column, token, project_name and username. Mutated in place when a
        local training file is processed.
    local : bool
        Forwarded to ``AutoTrainDataset``.

    Returns
    -------
    TabularParams
        The same ``params`` object. It is left unchanged when no local
        ``csv``/``jsonl`` training file exists.

    Raises
    ------
    ValueError
        If ``params.task`` is neither "classification" nor "regression", or if
        no target column is provided.
    """
    if isinstance(params.target_columns, str):
        col_map_label = [params.target_columns]
    else:
        col_map_label = params.target_columns

    # user-facing task -> (single-target AutoTrain task, multi-target AutoTrain task)
    task_names = {
        "classification": ("tabular_multi_class_classification", "tabular_multi_label_classification"),
        "regression": ("tabular_single_column_regression", "tabular_multi_column_regression"),
    }
    if params.task not in task_names:
        raise ValueError("Please select a valid task.")
    if not col_map_label:
        # Previously this fell through to the misleading "valid task" error.
        raise ValueError("Please provide at least one target column.")
    single_target_task, multi_target_task = task_names[params.task]
    task = multi_target_task if len(col_map_label) > 1 else single_target_task

    # First existing extension wins; stays None when no local file is present.
    ext_to_use = None
    for ext in ("csv", "jsonl"):
        if os.path.exists(f"{params.data_path}/{params.train_split}.{ext}"):
            ext_to_use = ext
            break

    train_data_path = f"{params.data_path}/{params.train_split}.{ext_to_use}"
    if params.valid_split is not None:
        valid_data_path = f"{params.data_path}/{params.valid_split}.{ext_to_use}"
    else:
        valid_data_path = None

    if os.path.exists(train_data_path):
        dset = AutoTrainDataset(
            train_data=[train_data_path],
            task=task,
            token=params.token,
            project_name=params.project_name,
            username=params.username,
            column_mapping={"id": params.id_column, "label": col_map_label},
            valid_data=[valid_data_path] if valid_data_path is not None else None,
            percent_valid=None,  # TODO: add to UI
            local=local,
            ext=ext_to_use,
        )
        params.data_path = dset.prepare()
        params.valid_split = "validation"
        params.id_column = "autotrain_id"
        if len(col_map_label) == 1:
            params.target_columns = ["autotrain_label"]
        else:
            params.target_columns = [f"autotrain_label_{i}" for i in range(len(col_map_label))]
    return params
def llm_munge_data(params, local):
    """
    Prepare a local CSV/JSONL dataset for LLM training (``task="lm_training"``).

    The training file is looked up as ``{data_path}/{train_split}.csv`` and then
    ``.jsonl``. When it exists, it is converted with ``AutoTrainDataset`` and the
    text/rejected/prompt column fields on ``params`` are repointed at the
    AutoTrain-generated names. These fields are set even if the corresponding
    input column was ``None``.

    Parameters
    ----------
    params : LLMTrainingParams
        Reads data_path, train_split, valid_split, text_column,
        rejected_text_column, prompt_text_column, token, project_name and
        username. Mutated in place when a local file is processed.
    local : bool
        Forwarded to ``AutoTrainDataset``.

    Returns
    -------
    LLMTrainingParams
        The same ``params`` object, untouched when no local training file exists.
    """
    split_stem = f"{params.data_path}/{params.train_split}"
    found_ext = next(
        (candidate for candidate in ("csv", "jsonl") if os.path.exists(f"{split_stem}.{candidate}")),
        None,
    )
    train_file = f"{split_stem}.{found_ext}"
    if params.valid_split is None:
        valid_files = None
    else:
        valid_files = [f"{params.data_path}/{params.valid_split}.{found_ext}"]

    if not os.path.exists(train_file):
        return params

    # Optional columns are only mapped when the user supplied them.
    mapping = {"text": params.text_column}
    if params.rejected_text_column is not None:
        mapping["rejected_text"] = params.rejected_text_column
    if params.prompt_text_column is not None:
        mapping["prompt"] = params.prompt_text_column

    prepared = AutoTrainDataset(
        train_data=[train_file],
        task="lm_training",
        token=params.token,
        project_name=params.project_name,
        username=params.username,
        column_mapping=mapping,
        valid_data=valid_files,
        percent_valid=None,  # TODO: add to UI
        local=local,
        ext=found_ext,
    )
    params.data_path = prepared.prepare()
    # NOTE(review): valid_split is reset to None here rather than "validation";
    # confirm this is intentional for LLM training.
    params.valid_split = None
    params.text_column = "autotrain_text"
    params.rejected_text_column = "autotrain_rejected_text"
    params.prompt_text_column = "autotrain_prompt"
    return params
def seq2seq_munge_data(params, local):
    """
    Prepare a local CSV/JSONL dataset for sequence-to-sequence training.

    The training file is looked up as ``{data_path}/{train_split}.csv`` and then
    ``.jsonl``. When it exists, it is converted with ``AutoTrainDataset``
    (``task="seq2seq"``). The ``text``/``label`` columns are mapped from
    ``text_column``/``target_column``, which are then repointed at the generated
    names.

    Parameters
    ----------
    params : Seq2SeqParams
        Mutated in place when a local training file is processed.
    local : bool
        Forwarded to ``AutoTrainDataset``.

    Returns
    -------
    Seq2SeqParams
        The same ``params`` object, untouched when no local training file exists.
    """
    split_stem = f"{params.data_path}/{params.train_split}"
    found_ext = next(
        (candidate for candidate in ("csv", "jsonl") if os.path.exists(f"{split_stem}.{candidate}")),
        None,
    )
    train_file = f"{split_stem}.{found_ext}"
    if params.valid_split is None:
        valid_files = None
    else:
        valid_files = [f"{params.data_path}/{params.valid_split}.{found_ext}"]

    if not os.path.exists(train_file):
        return params

    prepared = AutoTrainDataset(
        train_data=[train_file],
        task="seq2seq",
        token=params.token,
        project_name=params.project_name,
        username=params.username,
        column_mapping={"text": params.text_column, "label": params.target_column},
        valid_data=valid_files,
        percent_valid=None,  # TODO: add to UI
        local=local,
        ext=found_ext,
    )
    params.data_path = prepared.prepare()
    params.valid_split = "validation"
    params.text_column = "autotrain_text"
    params.target_column = "autotrain_label"
    return params
def text_clf_munge_data(params, local):
    """
    Prepare a local CSV/JSONL dataset for text classification.

    The training file is looked up as ``{data_path}/{train_split}.csv`` and then
    ``.jsonl``. When it exists, it is converted with ``AutoTrainDataset``
    (``task="text_multi_class_classification"``, ``convert_to_class_label=True``),
    and ``text_column``/``target_column`` are repointed at the generated names.

    Parameters
    ----------
    params : TextClassificationParams
        Mutated in place when a local training file is processed.
    local : bool
        Forwarded to ``AutoTrainDataset``.

    Returns
    -------
    TextClassificationParams
        The same ``params`` object, untouched when no local training file exists.
    """
    split_stem = f"{params.data_path}/{params.train_split}"
    found_ext = next(
        (candidate for candidate in ("csv", "jsonl") if os.path.exists(f"{split_stem}.{candidate}")),
        None,
    )
    train_file = f"{split_stem}.{found_ext}"
    if params.valid_split is None:
        valid_files = None
    else:
        valid_files = [f"{params.data_path}/{params.valid_split}.{found_ext}"]

    if not os.path.exists(train_file):
        return params

    prepared = AutoTrainDataset(
        train_data=[train_file],
        valid_data=valid_files,
        task="text_multi_class_classification",
        token=params.token,
        project_name=params.project_name,
        username=params.username,
        column_mapping={"text": params.text_column, "label": params.target_column},
        percent_valid=None,  # TODO: add to UI
        local=local,
        convert_to_class_label=True,
        ext=found_ext,
    )
    params.data_path = prepared.prepare()
    params.valid_split = "validation"
    params.text_column = "autotrain_text"
    params.target_column = "autotrain_label"
    return params
def text_reg_munge_data(params, local):
    """
    Prepare a local CSV/JSONL dataset for single-column text regression.

    The training file is looked up as ``{data_path}/{train_split}.csv`` and then
    ``.jsonl``. When it exists, it is converted with ``AutoTrainDataset``
    (``task="text_single_column_regression"``, ``convert_to_class_label=False``),
    and ``text_column``/``target_column`` are repointed at the generated names.

    Parameters
    ----------
    params : TextRegressionParams
        Mutated in place when a local training file is processed.
    local : bool
        Forwarded to ``AutoTrainDataset``.

    Returns
    -------
    TextRegressionParams
        The same ``params`` object, untouched when no local training file exists.
    """
    split_stem = f"{params.data_path}/{params.train_split}"
    found_ext = next(
        (candidate for candidate in ("csv", "jsonl") if os.path.exists(f"{split_stem}.{candidate}")),
        None,
    )
    train_file = f"{split_stem}.{found_ext}"
    if params.valid_split is None:
        valid_files = None
    else:
        valid_files = [f"{params.data_path}/{params.valid_split}.{found_ext}"]

    if not os.path.exists(train_file):
        return params

    prepared = AutoTrainDataset(
        train_data=[train_file],
        valid_data=valid_files,
        task="text_single_column_regression",
        token=params.token,
        project_name=params.project_name,
        username=params.username,
        column_mapping={"text": params.text_column, "label": params.target_column},
        percent_valid=None,  # TODO: add to UI
        local=local,
        convert_to_class_label=False,
        ext=found_ext,
    )
    params.data_path = prepared.prepare()
    params.valid_split = "validation"
    params.text_column = "autotrain_text"
    params.target_column = "autotrain_label"
    return params
def token_clf_munge_data(params, local):
    """
    Prepare a local CSV/JSONL dataset for token classification.

    The training file is looked up as ``{data_path}/{train_split}.csv`` and then
    ``.jsonl``. When it exists, it is converted with ``AutoTrainDataset``
    (``task="text_token_classification"``, ``convert_to_class_label=True``).
    ``tokens_column`` is mapped to ``text`` and ``tags_column`` to ``label``;
    both are then repointed at the generated names.

    Parameters
    ----------
    params : TokenClassificationParams
        Mutated in place when a local training file is processed.
    local : bool
        Forwarded to ``AutoTrainDataset``.

    Returns
    -------
    TokenClassificationParams
        The same ``params`` object, untouched when no local training file exists.
    """
    split_stem = f"{params.data_path}/{params.train_split}"
    found_ext = next(
        (candidate for candidate in ("csv", "jsonl") if os.path.exists(f"{split_stem}.{candidate}")),
        None,
    )
    train_file = f"{split_stem}.{found_ext}"
    if params.valid_split is None:
        valid_files = None
    else:
        valid_files = [f"{params.data_path}/{params.valid_split}.{found_ext}"]

    if not os.path.exists(train_file):
        return params

    prepared = AutoTrainDataset(
        train_data=[train_file],
        valid_data=valid_files,
        task="text_token_classification",
        token=params.token,
        project_name=params.project_name,
        username=params.username,
        column_mapping={"text": params.tokens_column, "label": params.tags_column},
        percent_valid=None,  # TODO: add to UI
        local=local,
        convert_to_class_label=True,
        ext=found_ext,
    )
    params.data_path = prepared.prepare()
    params.valid_split = "validation"
    params.tokens_column = "autotrain_text"
    params.tags_column = "autotrain_label"
    return params
def img_clf_munge_data(params, local):
    """
    Prepare a local image-classification dataset directory.

    ``{data_path}/{train_split}`` must be a directory. When it is, it is passed
    (together with the optional validation directory) to
    ``AutoTrainImageClassificationDataset``, and the image/target column fields
    are repointed at the generated names.

    Parameters
    ----------
    params : ImageClassificationParams
        Mutated in place when the training directory exists.
    local : bool
        Forwarded to ``AutoTrainImageClassificationDataset``.

    Returns
    -------
    ImageClassificationParams
        The same ``params`` object, untouched when the training directory is missing.
    """
    train_dir = f"{params.data_path}/{params.train_split}"
    valid_dir = None if params.valid_split is None else f"{params.data_path}/{params.valid_split}"

    if not os.path.isdir(train_dir):
        return params

    prepared = AutoTrainImageClassificationDataset(
        train_data=train_dir,
        valid_data=valid_dir,
        token=params.token,
        project_name=params.project_name,
        username=params.username,
        local=local,
    )
    params.data_path = prepared.prepare()
    params.valid_split = "validation"
    params.image_column = "autotrain_image"
    params.target_column = "autotrain_label"
    return params
def img_obj_detect_munge_data(params, local):
    """
    Prepare a local object-detection dataset directory.

    ``{data_path}/{train_split}`` must be a directory. When it is, it is passed
    (together with the optional validation directory) to
    ``AutoTrainObjectDetectionDataset``, and the image/objects column fields are
    repointed at the generated names.

    Parameters
    ----------
    params : ObjectDetectionParams
        Mutated in place when the training directory exists.
    local : bool
        Forwarded to ``AutoTrainObjectDetectionDataset``.

    Returns
    -------
    ObjectDetectionParams
        The same ``params`` object, untouched when the training directory is missing.
    """
    train_dir = f"{params.data_path}/{params.train_split}"
    valid_dir = None if params.valid_split is None else f"{params.data_path}/{params.valid_split}"

    if not os.path.isdir(train_dir):
        return params

    prepared = AutoTrainObjectDetectionDataset(
        train_data=train_dir,
        valid_data=valid_dir,
        token=params.token,
        project_name=params.project_name,
        username=params.username,
        local=local,
    )
    params.data_path = prepared.prepare()
    params.valid_split = "validation"
    params.image_column = "autotrain_image"
    params.objects_column = "autotrain_objects"
    return params
def sent_transformers_munge_data(params, local):
    """
    Prepare a local CSV/JSONL dataset for sentence-transformers training.

    The training file is looked up as ``{data_path}/{train_split}.csv`` and then
    ``.jsonl``. When it exists, it is converted with ``AutoTrainDataset``
    (``task="sentence_transformers"``). The three sentence columns and the target
    column are mapped, and all four column fields on ``params`` are then
    repointed at the generated names, even those that were ``None``.

    Parameters
    ----------
    params : SentenceTransformersParams
        Reads data_path, train_split, valid_split, sentence1/2/3_column,
        target_column, trainer, token, project_name and username. Mutated in
        place when a local training file is processed.
    local : bool
        Forwarded to ``AutoTrainDataset``.

    Returns
    -------
    SentenceTransformersParams
        The same ``params`` object, untouched when no local training file exists.
    """
    ext_to_use = None
    for ext in ("csv", "jsonl"):
        if os.path.exists(f"{params.data_path}/{params.train_split}.{ext}"):
            ext_to_use = ext
            break

    train_data_path = f"{params.data_path}/{params.train_split}.{ext_to_use}"
    if params.valid_split is not None:
        valid_data_path = f"{params.data_path}/{params.valid_split}.{ext_to_use}"
    else:
        valid_data_path = None

    if os.path.exists(train_data_path):
        dset = AutoTrainDataset(
            train_data=[train_data_path],
            valid_data=[valid_data_path] if valid_data_path is not None else None,
            task="sentence_transformers",
            token=params.token,
            project_name=params.project_name,
            username=params.username,
            column_mapping={
                "sentence1": params.sentence1_column,
                "sentence2": params.sentence2_column,
                "sentence3": params.sentence3_column,
                "target": params.target_column,
            },
            percent_valid=None,  # TODO: add to UI
            local=local,
            # Only the "pair_class" trainer needs targets encoded as class labels.
            convert_to_class_label=params.trainer == "pair_class",
            ext=ext_to_use,
        )
        params.data_path = dset.prepare()
        params.valid_split = "validation"
        params.sentence1_column = "autotrain_sentence1"
        params.sentence2_column = "autotrain_sentence2"
        params.sentence3_column = "autotrain_sentence3"
        params.target_column = "autotrain_target"
    return params
def img_reg_munge_data(params, local):
    """
    Prepare a local image-regression dataset directory.

    ``{data_path}/{train_split}`` must be a directory. When it is, it is passed
    (together with the optional validation directory) to
    ``AutoTrainImageRegressionDataset``, and the image/target column fields are
    repointed at the generated names.

    Parameters
    ----------
    params : ImageRegressionParams
        Mutated in place when the training directory exists.
    local : bool
        Forwarded to ``AutoTrainImageRegressionDataset``.

    Returns
    -------
    ImageRegressionParams
        The same ``params`` object, untouched when the training directory is missing.
    """
    train_dir = f"{params.data_path}/{params.train_split}"
    valid_dir = None if params.valid_split is None else f"{params.data_path}/{params.valid_split}"

    if not os.path.isdir(train_dir):
        return params

    prepared = AutoTrainImageRegressionDataset(
        train_data=train_dir,
        valid_data=valid_dir,
        token=params.token,
        project_name=params.project_name,
        username=params.username,
        local=local,
    )
    params.data_path = prepared.prepare()
    params.valid_split = "validation"
    params.image_column = "autotrain_image"
    params.target_column = "autotrain_label"
    return params
def vlm_munge_data(params, local):
    """
    Prepare a local dataset for vision-language model training.

    ``{data_path}/{train_split}`` is passed to ``AutoTrainVLMDataset`` as-is, with
    no extension appended. The existence check uses ``os.path.exists``, so either
    a file or a directory is accepted. The text column is always mapped; the
    prompt column is mapped only when provided. Afterwards the text, image and
    prompt column fields are repointed at the generated names.

    Parameters
    ----------
    params : VLMTrainingParams
        Reads data_path, train_split, valid_split, text_column,
        prompt_text_column, token, project_name and username. Mutated in place
        when the training path exists.
    local : bool
        Forwarded to ``AutoTrainVLMDataset``.

    Returns
    -------
    VLMTrainingParams
        The same ``params`` object, untouched when the training path is missing.
    """
    train_data_path = f"{params.data_path}/{params.train_split}"
    if params.valid_split is not None:
        valid_data_path = f"{params.data_path}/{params.valid_split}"
    else:
        valid_data_path = None

    if os.path.exists(train_data_path):
        col_map = {"text": params.text_column}
        if params.prompt_text_column is not None:
            col_map["prompt"] = params.prompt_text_column
        dset = AutoTrainVLMDataset(
            train_data=train_data_path,
            token=params.token,
            project_name=params.project_name,
            username=params.username,
            column_mapping=col_map,
            valid_data=valid_data_path,  # already None when there is no validation split
            percent_valid=None,  # TODO: add to UI
            local=local,
        )
        params.data_path = dset.prepare()
        # NOTE(review): valid_split is left as the user's value here, whereas the
        # other munge helpers set it to "validation"; confirm what prepare() produces.
        params.text_column = "autotrain_text"
        params.image_column = "autotrain_image"
        params.prompt_text_column = "autotrain_prompt"
    return params
def ext_qa_munge_data(params, local):
    """
    Prepare a local CSV/JSONL dataset for extractive question answering.

    The training file is looked up as ``{data_path}/{train_split}.csv`` and then
    ``.jsonl``. When it exists, it is converted with ``AutoTrainDataset``
    (``task="text_extractive_question_answering"``). The text, question and
    answer columns are mapped and then repointed at the generated names.

    Parameters
    ----------
    params : ExtractiveQuestionAnsweringParams
        Mutated in place when a local training file is processed.
    local : bool
        Forwarded to ``AutoTrainDataset``.

    Returns
    -------
    ExtractiveQuestionAnsweringParams
        The same ``params`` object, untouched when no local training file exists.
    """
    split_stem = f"{params.data_path}/{params.train_split}"
    found_ext = next(
        (candidate for candidate in ("csv", "jsonl") if os.path.exists(f"{split_stem}.{candidate}")),
        None,
    )
    train_file = f"{split_stem}.{found_ext}"
    if params.valid_split is None:
        valid_files = None
    else:
        valid_files = [f"{params.data_path}/{params.valid_split}.{found_ext}"]

    if not os.path.exists(train_file):
        return params

    mapping = {
        "text": params.text_column,
        "question": params.question_column,
        "answer": params.answer_column,
    }
    prepared = AutoTrainDataset(
        train_data=[train_file],
        valid_data=valid_files,
        task="text_extractive_question_answering",
        token=params.token,
        project_name=params.project_name,
        username=params.username,
        column_mapping=mapping,
        percent_valid=None,  # TODO: add to UI
        local=local,
        # NOTE(review): class-label conversion is requested for a QA task;
        # confirm AutoTrainDataset ignores or needs it for this task.
        convert_to_class_label=True,
        ext=found_ext,
    )
    params.data_path = prepared.prepare()
    params.valid_split = "validation"
    params.text_column = "autotrain_text"
    params.question_column = "autotrain_question"
    params.answer_column = "autotrain_answer"
    return params
@dataclass
class AutoTrainProject:
    """
    A class to train an AutoTrain project.

    Attributes
    ----------
    params : Union[
        LLMTrainingParams,
        TextClassificationParams,
        TabularParams,
        Seq2SeqParams,
        ImageClassificationParams,
        TextRegressionParams,
        ObjectDetectionParams,
        TokenClassificationParams,
        SentenceTransformersParams,
        ImageRegressionParams,
        ExtractiveQuestionAnsweringParams,
        VLMTrainingParams,
    ]
        The parameters for the AutoTrain project.
    backend : str
        The backend to be used for the AutoTrain project. It must be a member of
        ``AVAILABLE_HARDWARE``; ``create`` selects the runner by prefix:
        - ``local...`` -> LocalRunner
        - ``spaces-...`` -> SpaceRunner, e.g.
          spaces-a10g-large, spaces-a10g-small, spaces-a100-large,
          spaces-t4-medium, spaces-t4-small, spaces-cpu-upgrade,
          spaces-cpu-basic, spaces-l4x1, spaces-l4x4, spaces-l40sx1,
          spaces-l40sx4, spaces-l40sx8, spaces-a10g-largex2, spaces-a10g-largex4
        - ``ep-...`` -> EndpointsRunner
        - ``ngc-...`` -> NGCRunner
        - ``nvcf-...`` -> NVCFRunner
    process : bool
        Flag to indicate if the params and dataset should be processed. If your data
        format is not AutoTrain-readable, set it to True. Set it to True when in doubt.
        Defaults to False.
    local : bool
        Set by ``__post_init__``; True when ``backend`` starts with "local".
        Forwarded to the data-munging helpers.

    Methods
    -------
    __post_init__():
        Sets ``local`` and validates the backend attribute.
    create():
        Optionally processes the data, then creates a runner based on the backend
        and starts the AutoTrain project.
    """

    params: Union[
        LLMTrainingParams,
        TextClassificationParams,
        TabularParams,
        Seq2SeqParams,
        ImageClassificationParams,
        TextRegressionParams,
        ObjectDetectionParams,
        TokenClassificationParams,
        SentenceTransformersParams,
        ImageRegressionParams,
        ExtractiveQuestionAnsweringParams,
        VLMTrainingParams,
    ]
    backend: str
    process: bool = False

    def __post_init__(self):
        """
        Derive ``self.local`` and validate ``self.backend``.

        Raises
        ------
        ValueError
            If ``backend`` is not in ``AVAILABLE_HARDWARE``.
        """
        self.local = self.backend.startswith("local")
        if self.backend not in AVAILABLE_HARDWARE:
            raise ValueError(f"Invalid backend: {self.backend}")

    def _process_params_data(self):
        """
        Run the ``*_munge_data`` helper that matches the type of ``self.params``.

        Returns
        -------
        The params object returned by the selected helper.

        Raises
        ------
        TypeError
            If ``self.params`` is not an instance of any supported params class.
        """
        # Checked in order with isinstance(), exactly as the former if/elif chain
        # did, so any subclass relationships between params classes resolve the same.
        munge_functions = (
            (LLMTrainingParams, llm_munge_data),
            (ExtractiveQuestionAnsweringParams, ext_qa_munge_data),
            (ImageClassificationParams, img_clf_munge_data),
            (ImageRegressionParams, img_reg_munge_data),
            (ObjectDetectionParams, img_obj_detect_munge_data),
            (SentenceTransformersParams, sent_transformers_munge_data),
            (Seq2SeqParams, seq2seq_munge_data),
            (TabularParams, tabular_munge_data),
            (TextClassificationParams, text_clf_munge_data),
            (TextRegressionParams, text_reg_munge_data),
            (TokenClassificationParams, token_clf_munge_data),
            (VLMTrainingParams, vlm_munge_data),
        )
        for params_cls, munge_fn in munge_functions:
            if isinstance(self.params, params_cls):
                return munge_fn(self.params, self.local)
        # TypeError subclasses Exception, so callers catching Exception still work.
        raise TypeError(f"Invalid params class: {type(self.params).__name__}")

    def create(self):
        """
        Optionally preprocess the data, then launch the project on the configured backend.

        When ``process`` is True, ``self.params`` is replaced by the output of
        ``_process_params_data`` before the runner is built.

        Returns
        -------
        Whatever the selected runner's ``create()`` returns.

        Raises
        ------
        NotImplementedError
            If ``backend`` matches none of the known runner prefixes.
        """
        if self.process:
            self.params = self._process_params_data()

        # Prefixes are tested in this order, matching the former if/elif chain.
        runners_by_prefix = (
            ("local", LocalRunner),
            ("spaces-", SpaceRunner),
            ("ep-", EndpointsRunner),
            ("ngc-", NGCRunner),
            ("nvcf-", NVCFRunner),
        )
        for prefix, runner_cls in runners_by_prefix:
            if self.backend.startswith(prefix):
                runner = runner_cls(params=self.params, backend=self.backend)
                return runner.create()
        raise NotImplementedError(f"No runner available for backend: {self.backend}")