hardiktiwari's picture
Upload 244 files
33d4721 verified
import json
from dataclasses import dataclass
from typing import Union
from autotrain.trainers.clm.params import LLMTrainingParams
from autotrain.trainers.extractive_question_answering.params import ExtractiveQuestionAnsweringParams
from autotrain.trainers.generic.params import GenericParams
from autotrain.trainers.image_classification.params import ImageClassificationParams
from autotrain.trainers.image_regression.params import ImageRegressionParams
from autotrain.trainers.object_detection.params import ObjectDetectionParams
from autotrain.trainers.sent_transformers.params import SentenceTransformersParams
from autotrain.trainers.seq2seq.params import Seq2SeqParams
from autotrain.trainers.tabular.params import TabularParams
from autotrain.trainers.text_classification.params import TextClassificationParams
from autotrain.trainers.text_regression.params import TextRegressionParams
from autotrain.trainers.token_classification.params import TokenClassificationParams
from autotrain.trainers.vlm.params import VLMTrainingParams
# Maps each user-facing backend key to the hardware identifier handed to the
# corresponding runner: a plain string for Spaces, NGC, Endpoints and local
# backends, or a dict with an "id" entry for NVCF backends.
# Insertion order is significant for anything that lists the available keys.
AVAILABLE_HARDWARE = {
    # Hugging Face Spaces: each key is its hardware value prefixed with "spaces-".
    **{
        f"spaces-{flavor}": flavor
        for flavor in (
            "a10g-large",
            "a10g-small",
            "a100-large",
            "t4-medium",
            "t4-small",
            "cpu-upgrade",
            "cpu-basic",
            "l4x1",
            "l4x4",
            "l40sx1",
            "l40sx4",
            "l40sx8",
            "a10g-largex2",
            "a10g-largex4",
        )
    },
    # NGC: DGX A100 instances with 1, 2, 4 or 8 GPUs.
    "dgx-a100": "dgxa100.80g.1.norm",
    "dgx-2a100": "dgxa100.80g.2.norm",
    "dgx-4a100": "dgxa100.80g.4.norm",
    "dgx-8a100": "dgxa100.80g.8.norm",
    # Hugging Face Inference Endpoints (AWS us-east-1 instance types).
    "ep-aws-useast1-s": "aws_us-east-1_gpu_small_g4dn.xlarge",
    "ep-aws-useast1-m": "aws_us-east-1_gpu_medium_g5.2xlarge",
    "ep-aws-useast1-l": "aws_us-east-1_gpu_large_g4dn.12xlarge",
    "ep-aws-useast1-xl": "aws_us-east-1_gpu_xlarge_p4de",
    "ep-aws-useast1-2xl": "aws_us-east-1_gpu_2xlarge_p4de",
    "ep-aws-useast1-4xl": "aws_us-east-1_gpu_4xlarge_p4de",
    "ep-aws-useast1-8xl": "aws_us-east-1_gpu_8xlarge_p4de",
    # NVCF: the value is a dict carrying an opaque "id".
    "nvcf-l40sx1": {"id": "67bb8939-c932-429a-a446-8ae898311856"},
    "nvcf-h100x1": {"id": "848348f8-a4e2-4242-bce9-6baa1bd70a66"},
    "nvcf-h100x2": {"id": "fb006a89-451e-4d9c-82b5-33eff257e0bf"},
    "nvcf-h100x4": {"id": "21bae5af-87e5-4132-8fc0-bf3084e59a57"},
    "nvcf-h100x8": {"id": "6e0c2af6-5368-47e0-b15e-c070c2c92018"},
    # Local execution: every local variant maps to the same "local" target.
    "local-ui": "local",
    "local": "local",
    "local-cli": "local",
}
@dataclass
class BaseBackend:
    """
    Shared configuration for all AutoTrain backends.

    After the dataclass constructor runs, ``__post_init__``:

    - validates the backend / parameter combination;
    - resolves the username required by remote backends;
    - picks the numeric task ID for the type of ``params``;
    - decides whether the caller should wait for the job;
    - builds the environment variables handed to the training job.

    Attributes:
        params: Training parameters; an instance of one of the ``*Params``
            classes listed in the annotation below.
        backend (str): Backend key, e.g. ``"local"``, ``"local-ui"``,
            ``"spaces-a10g-large"`` or ``"nvcf-h100x1"``.

    Attributes set by ``__post_init__``:
        username (str | None): ``params.username`` for backends whose key starts
            with ``"spaces-"``, ``"ep-"``, ``"ngc-"`` or ``"nvcf-"``; otherwise None.
        task_id (int): Numeric task identifier for the type of ``params``.
        available_hardware (dict): The module-level ``AVAILABLE_HARDWARE`` mapping.
        wait (bool): True only for the ``"local"`` and ``"local-cli"`` backends.
        env_vars (dict): Environment variables for the training job.

    Raises (from ``__post_init__``):
        ValueError: ``GenericParams`` used with a ``"local*"`` backend, or a
            remote backend used without ``params.username``.
        NotImplementedError: ``params`` is not one of the supported types.
    """

    params: Union[
        TextClassificationParams,
        ImageClassificationParams,
        LLMTrainingParams,
        GenericParams,
        TabularParams,
        Seq2SeqParams,
        TokenClassificationParams,
        TextRegressionParams,
        ObjectDetectionParams,
        SentenceTransformersParams,
        ImageRegressionParams,
        VLMTrainingParams,
        ExtractiveQuestionAnsweringParams,
    ]
    backend: str

    # Backend-key prefixes that run remotely and therefore need a username.
    # Not annotated, so dataclass does not treat this as a field.
    # NOTE(review): the NGC keys in AVAILABLE_HARDWARE start with "dgx-", not
    # "ngc-", so they do not trigger the username check. Confirm which prefix
    # the NGC runner actually receives.
    _REMOTE_PREFIXES = ("spaces-", "ep-", "ngc-", "nvcf-")

    # (params class, task id) pairs, checked in order with isinstance().
    # The order is kept identical to the original if/elif chain so that any
    # subclass relationship resolves to the same task id.
    _TASK_IDS = (
        (LLMTrainingParams, 9),
        (TextClassificationParams, 2),
        (TabularParams, 26),
        (GenericParams, 27),
        (Seq2SeqParams, 28),
        (ImageClassificationParams, 18),
        (TokenClassificationParams, 4),
        (TextRegressionParams, 10),
        (ObjectDetectionParams, 29),
        (SentenceTransformersParams, 30),
        (ImageRegressionParams, 24),
        (VLMTrainingParams, 31),
        (ExtractiveQuestionAnsweringParams, 5),
    )

    def _resolve_task_id(self) -> int:
        """Return the task ID for ``self.params``.

        Raises:
            NotImplementedError: if ``self.params`` matches none of ``_TASK_IDS``.
        """
        for params_cls, task_id in self._TASK_IDS:
            if isinstance(self.params, params_cls):
                return task_id
        raise NotImplementedError(f"Unsupported params type: {type(self.params).__name__}")

    def __post_init__(self) -> None:
        """Validate the configuration and derive task id, wait flag and env vars."""
        self.username = None

        if isinstance(self.params, GenericParams) and self.backend.startswith("local"):
            raise ValueError("Local backend is not supported for GenericParams")

        if self.backend.startswith(self._REMOTE_PREFIXES):
            if self.params.username is None:
                raise ValueError("Must provide username")
            self.username = self.params.username

        self.task_id = self._resolve_task_id()

        self.available_hardware = AVAILABLE_HARDWARE
        # Only the blocking local variants wait; "local-ui" and remote backends do not.
        self.wait = self.backend in ("local", "local-cli")

        self.env_vars = {
            "HF_TOKEN": self.params.token,
            "AUTOTRAIN_USERNAME": self.username,
            "PROJECT_NAME": self.params.project_name,
            "TASK_ID": str(self.task_id),
            # model_dump_json() presumably returns a JSON string (as pydantic's does).
            # json.dumps then encodes that string a second time, so decoding PARAMS
            # once yields the inner JSON text, not a dict. Kept as-is because the
            # consumers of PARAMS depend on this format.
            "PARAMS": json.dumps(self.params.model_dump_json()),
            "DATA_PATH": self.params.data_path,
        }
        if not isinstance(self.params, GenericParams):
            self.env_vars["MODEL"] = self.params.model