Spaces:
Sleeping
Sleeping
import os | |
from dataclasses import dataclass | |
import requests | |
import yaml | |
from autotrain import logger | |
from autotrain.project import ( | |
AutoTrainProject, | |
ext_qa_munge_data, | |
img_clf_munge_data, | |
img_obj_detect_munge_data, | |
img_reg_munge_data, | |
llm_munge_data, | |
sent_transformers_munge_data, | |
seq2seq_munge_data, | |
tabular_munge_data, | |
text_clf_munge_data, | |
text_reg_munge_data, | |
token_clf_munge_data, | |
vlm_munge_data, | |
) | |
from autotrain.tasks import TASKS | |
from autotrain.trainers.clm.params import LLMTrainingParams | |
from autotrain.trainers.extractive_question_answering.params import ExtractiveQuestionAnsweringParams | |
from autotrain.trainers.image_classification.params import ImageClassificationParams | |
from autotrain.trainers.image_regression.params import ImageRegressionParams | |
from autotrain.trainers.object_detection.params import ObjectDetectionParams | |
from autotrain.trainers.sent_transformers.params import SentenceTransformersParams | |
from autotrain.trainers.seq2seq.params import Seq2SeqParams | |
from autotrain.trainers.tabular.params import TabularParams | |
from autotrain.trainers.text_classification.params import TextClassificationParams | |
from autotrain.trainers.text_regression.params import TextRegressionParams | |
from autotrain.trainers.token_classification.params import TokenClassificationParams | |
from autotrain.trainers.vlm.params import VLMTrainingParams | |
class AutoTrainConfigParser: | |
""" | |
AutoTrainConfigParser is a class responsible for parsing and validating the yaml configuration | |
required to run various tasks in the AutoTrain framework. It supports loading configurations | |
from both local files and remote URLs, and maps task aliases to their respective parameters | |
and data munging functions. | |
Attributes: | |
config_path (str): Path or URL to the configuration file. | |
config (dict): Parsed configuration data. | |
task_param_map (dict): Mapping of task names to their parameter classes. | |
munge_data_map (dict): Mapping of task names to their data munging functions. | |
task_aliases (dict): Mapping of task aliases to their canonical task names. | |
task (str): The resolved task name from the configuration. | |
backend (str): The backend specified in the configuration. | |
parsed_config (dict): The parsed configuration parameters. | |
Methods: | |
__post_init__(): Initializes the parser, loads the configuration, and validates required fields. | |
_parse_config(): Parses the configuration and extracts relevant parameters based on the task. | |
run(): Executes the task with the parsed configuration. | |
""" | |
config_path: str | |
def __post_init__(self): | |
if self.config_path.startswith("http"): | |
response = requests.get(self.config_path) | |
if response.status_code == 200: | |
self.config = yaml.safe_load(response.content) | |
else: | |
raise ValueError("Failed to retrieve YAML file.") | |
else: | |
with open(self.config_path, "r") as f: | |
self.config = yaml.safe_load(f) | |
self.task_param_map = { | |
"lm_training": LLMTrainingParams, | |
"image_binary_classification": ImageClassificationParams, | |
"image_multi_class_classification": ImageClassificationParams, | |
"image_object_detection": ObjectDetectionParams, | |
"seq2seq": Seq2SeqParams, | |
"tabular": TabularParams, | |
"text_binary_classification": TextClassificationParams, | |
"text_multi_class_classification": TextClassificationParams, | |
"text_single_column_regression": TextRegressionParams, | |
"text_token_classification": TokenClassificationParams, | |
"sentence_transformers": SentenceTransformersParams, | |
"image_single_column_regression": ImageRegressionParams, | |
"vlm": VLMTrainingParams, | |
"text_extractive_question_answering": ExtractiveQuestionAnsweringParams, | |
} | |
self.munge_data_map = { | |
"lm_training": llm_munge_data, | |
"tabular": tabular_munge_data, | |
"seq2seq": seq2seq_munge_data, | |
"image_multi_class_classification": img_clf_munge_data, | |
"image_object_detection": img_obj_detect_munge_data, | |
"text_multi_class_classification": text_clf_munge_data, | |
"text_token_classification": token_clf_munge_data, | |
"text_single_column_regression": text_reg_munge_data, | |
"sentence_transformers": sent_transformers_munge_data, | |
"image_single_column_regression": img_reg_munge_data, | |
"vlm": vlm_munge_data, | |
"text_extractive_question_answering": ext_qa_munge_data, | |
} | |
self.task_aliases = { | |
"llm": "lm_training", | |
"llm-sft": "lm_training", | |
"llm-orpo": "lm_training", | |
"llm-generic": "lm_training", | |
"llm-dpo": "lm_training", | |
"llm-reward": "lm_training", | |
"image_binary_classification": "image_multi_class_classification", | |
"image-binary-classification": "image_multi_class_classification", | |
"image_classification": "image_multi_class_classification", | |
"image-classification": "image_multi_class_classification", | |
"seq2seq": "seq2seq", | |
"tabular": "tabular", | |
"text_binary_classification": "text_multi_class_classification", | |
"text-binary-classification": "text_multi_class_classification", | |
"text_classification": "text_multi_class_classification", | |
"text-classification": "text_multi_class_classification", | |
"text_single_column_regression": "text_single_column_regression", | |
"text-single-column-regression": "text_single_column_regression", | |
"text_regression": "text_single_column_regression", | |
"text-regression": "text_single_column_regression", | |
"token_classification": "text_token_classification", | |
"token-classification": "text_token_classification", | |
"image_object_detection": "image_object_detection", | |
"image-object-detection": "image_object_detection", | |
"object_detection": "image_object_detection", | |
"object-detection": "image_object_detection", | |
"st": "sentence_transformers", | |
"st:pair": "sentence_transformers", | |
"st:pair_class": "sentence_transformers", | |
"st:pair_score": "sentence_transformers", | |
"st:triplet": "sentence_transformers", | |
"st:qa": "sentence_transformers", | |
"sentence-transformers:pair": "sentence_transformers", | |
"sentence-transformers:pair_class": "sentence_transformers", | |
"sentence-transformers:pair_score": "sentence_transformers", | |
"sentence-transformers:triplet": "sentence_transformers", | |
"sentence-transformers:qa": "sentence_transformers", | |
"image_single_column_regression": "image_single_column_regression", | |
"image-single-column-regression": "image_single_column_regression", | |
"image_regression": "image_single_column_regression", | |
"image-regression": "image_single_column_regression", | |
"image-scoring": "image_single_column_regression", | |
"vlm:captioning": "vlm", | |
"vlm:vqa": "vlm", | |
"extractive_question_answering": "text_extractive_question_answering", | |
"ext_qa": "text_extractive_question_answering", | |
"ext-qa": "text_extractive_question_answering", | |
"extractive-qa": "text_extractive_question_answering", | |
} | |
task = self.config.get("task") | |
self.task = self.task_aliases.get(task, task) | |
if self.task is None: | |
raise ValueError("Task is required in the configuration file") | |
if self.task not in TASKS: | |
raise ValueError(f"Task `{self.task}` is not supported") | |
self.backend = self.config.get("backend") | |
if self.backend is None: | |
raise ValueError("Backend is required in the configuration file") | |
logger.info(f"Running task: {self.task}") | |
logger.info(f"Using backend: {self.backend}") | |
self.parsed_config = self._parse_config() | |
def _parse_config(self): | |
params = { | |
"model": self.config["base_model"], | |
"project_name": self.config["project_name"], | |
} | |
params["data_path"] = self.config["data"]["path"] | |
if self.task == "lm_training": | |
params["chat_template"] = self.config["data"]["chat_template"] | |
if "-" in self.config["task"]: | |
params["trainer"] = self.config["task"].split("-")[1] | |
if params["trainer"] == "generic": | |
params["trainer"] = "default" | |
if params["trainer"] not in ["sft", "orpo", "dpo", "reward", "default"]: | |
raise ValueError("Invalid LLM training task") | |
if self.task == "sentence_transformers": | |
params["trainer"] = self.config["task"].split(":")[1] | |
if self.task == "vlm": | |
params["trainer"] = self.config["task"].split(":")[1] | |
for k, v in self.config["data"]["column_mapping"].items(): | |
params[k] = v | |
params["train_split"] = self.config["data"]["train_split"] | |
params["valid_split"] = self.config["data"]["valid_split"] | |
params["log"] = self.config["log"] | |
if "hub" in self.config: | |
params["username"] = self.config["hub"]["username"] | |
params["token"] = self.config["hub"]["token"] | |
params["push_to_hub"] = self.config["hub"]["push_to_hub"] | |
else: | |
params["username"] = None | |
params["token"] = None | |
params["push_to_hub"] = False | |
if params["username"]: | |
if params["username"].startswith("${"): | |
params["username"] = os.environ.get(params["username"][2:-1]) | |
if params["token"]: | |
if params["token"].startswith("${"): | |
params["token"] = os.environ.get(params["token"][2:-1]) | |
other_params = self.config.get("params") | |
if other_params: | |
params.update(other_params) | |
return params | |
def run(self): | |
_params = self.task_param_map[self.task](**self.parsed_config) | |
logger.info(_params) | |
_munge_fn = self.munge_data_map[self.task] | |
_munge_fn(_params, local=self.backend.startswith("local")) | |
project = AutoTrainProject(params=_params, backend=self.backend) | |
job_id = project.create() | |
logger.info(f"Job ID: {job_id}") | |