Spaces:
Sleeping
Sleeping
File size: 6,587 Bytes
33d4721 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 |
import json
from dataclasses import dataclass
from typing import Union
from autotrain.trainers.clm.params import LLMTrainingParams
from autotrain.trainers.extractive_question_answering.params import ExtractiveQuestionAnsweringParams
from autotrain.trainers.generic.params import GenericParams
from autotrain.trainers.image_classification.params import ImageClassificationParams
from autotrain.trainers.image_regression.params import ImageRegressionParams
from autotrain.trainers.object_detection.params import ObjectDetectionParams
from autotrain.trainers.sent_transformers.params import SentenceTransformersParams
from autotrain.trainers.seq2seq.params import Seq2SeqParams
from autotrain.trainers.tabular.params import TabularParams
from autotrain.trainers.text_classification.params import TextClassificationParams
from autotrain.trainers.text_regression.params import TextRegressionParams
from autotrain.trainers.token_classification.params import TokenClassificationParams
from autotrain.trainers.vlm.params import VLMTrainingParams
# Map of backend name -> provider-specific hardware identifier.
# Insertion order is significant only for display/iteration; lookups are by key.
AVAILABLE_HARDWARE = {
    # hugging face spaces: backend "spaces-<flavor>" maps to the bare "<flavor>"
    **{
        f"spaces-{flavor}": flavor
        for flavor in (
            "a10g-large",
            "a10g-small",
            "a100-large",
            "t4-medium",
            "t4-small",
            "cpu-upgrade",
            "cpu-basic",
            "l4x1",
            "l4x4",
            "l40sx1",
            "l40sx4",
            "l40sx8",
            "a10g-largex2",
            "a10g-largex4",
        )
    },
    # ngc: backend name -> NGC instance type string
    "dgx-a100": "dgxa100.80g.1.norm",
    "dgx-2a100": "dgxa100.80g.2.norm",
    "dgx-4a100": "dgxa100.80g.4.norm",
    "dgx-8a100": "dgxa100.80g.8.norm",
    # hugging face endpoints: backend name -> "<vendor>_<region>_<type>_<size>_<instance>"
    "ep-aws-useast1-s": "aws_us-east-1_gpu_small_g4dn.xlarge",
    "ep-aws-useast1-m": "aws_us-east-1_gpu_medium_g5.2xlarge",
    "ep-aws-useast1-l": "aws_us-east-1_gpu_large_g4dn.12xlarge",
    "ep-aws-useast1-xl": "aws_us-east-1_gpu_xlarge_p4de",
    "ep-aws-useast1-2xl": "aws_us-east-1_gpu_2xlarge_p4de",
    "ep-aws-useast1-4xl": "aws_us-east-1_gpu_4xlarge_p4de",
    "ep-aws-useast1-8xl": "aws_us-east-1_gpu_8xlarge_p4de",
    # nvcf: unlike the other providers, each value is a dict holding an "id"
    "nvcf-l40sx1": {"id": "67bb8939-c932-429a-a446-8ae898311856"},
    "nvcf-h100x1": {"id": "848348f8-a4e2-4242-bce9-6baa1bd70a66"},
    "nvcf-h100x2": {"id": "fb006a89-451e-4d9c-82b5-33eff257e0bf"},
    "nvcf-h100x4": {"id": "21bae5af-87e5-4132-8fc0-bf3084e59a57"},
    "nvcf-h100x8": {"id": "6e0c2af6-5368-47e0-b15e-c070c2c92018"},
    # local: every local variant maps to the same "local" value
    **dict.fromkeys(("local-ui", "local", "local-cli"), "local"),
}
# Ordered (params class, task id) pairs consulted by BaseBackend.__post_init__.
# The first isinstance() match wins, exactly like the if/elif chain this table
# replaces, so keep the order stable in case one params class subclasses another.
_TASK_IDS = (
    (LLMTrainingParams, 9),
    (TextClassificationParams, 2),
    (TabularParams, 26),
    (GenericParams, 27),
    (Seq2SeqParams, 28),
    (ImageClassificationParams, 18),
    (TokenClassificationParams, 4),
    (TextRegressionParams, 10),
    (ObjectDetectionParams, 29),
    (SentenceTransformersParams, 30),
    (ImageRegressionParams, 24),
    (VLMTrainingParams, 31),
    (ExtractiveQuestionAnsweringParams, 5),
)

# Backend-name prefixes for which params.username is mandatory.
_USERNAME_REQUIRED_PREFIXES = ("spaces-", "ep-", "ngc-", "nvcf-")


@dataclass
class BaseBackend:
    """
    BaseBackend class is responsible for initializing and validating backend configurations
    for various training parameters. It supports multiple types of training parameters
    including text classification, image classification, LLM training, and more.

    Attributes:
        params (Union[TextClassificationParams, ImageClassificationParams, LLMTrainingParams,
                      GenericParams, TabularParams, Seq2SeqParams,
                      TokenClassificationParams, TextRegressionParams, ObjectDetectionParams,
                      SentenceTransformersParams, ImageRegressionParams, VLMTrainingParams,
                      ExtractiveQuestionAnsweringParams]): Training parameters.
        backend (str): Backend name, e.g. a key of AVAILABLE_HARDWARE such as
            "spaces-a10g-large" or "local-cli".

    Attributes set by __post_init__:
        username (str | None): params.username for backends whose name starts with
            "spaces-", "ep-", "ngc-" or "nvcf-"; None otherwise.
        task_id (int): Numeric task identifier derived from the params class.
        available_hardware (dict): The module-level AVAILABLE_HARDWARE mapping
            (the same object, not a copy).
        wait (bool): True only for the "local" and "local-cli" backends.
        env_vars (dict): Environment variables describing the job (HF_TOKEN,
            AUTOTRAIN_USERNAME, PROJECT_NAME, TASK_ID, PARAMS, DATA_PATH and,
            except for GenericParams, MODEL).

    Raises:
        ValueError: GenericParams used with a "local*" backend, or a
            username-requiring backend used without params.username.
        NotImplementedError: params is not one of the supported params types.
    """

    params: Union[
        TextClassificationParams,
        ImageClassificationParams,
        LLMTrainingParams,
        GenericParams,
        TabularParams,
        Seq2SeqParams,
        TokenClassificationParams,
        TextRegressionParams,
        ObjectDetectionParams,
        SentenceTransformersParams,
        ImageRegressionParams,
        VLMTrainingParams,
        ExtractiveQuestionAnsweringParams,
    ]
    backend: str

    def __post_init__(self):
        """Validate params/backend, resolve the task id and build env_vars."""
        self.username = None

        if isinstance(self.params, GenericParams) and self.backend.startswith("local"):
            raise ValueError("Local backend is not supported for GenericParams")

        # NOTE(review): the NGC entries in AVAILABLE_HARDWARE are keyed "dgx-...",
        # which none of these prefixes match, so such backends skip this username
        # check. Confirm whether "ngc-" is the intended prefix here.
        if self.backend.startswith(_USERNAME_REQUIRED_PREFIXES):
            if self.params.username is None:
                raise ValueError("Must provide username")
            self.username = self.params.username

        for params_cls, task_id in _TASK_IDS:
            if isinstance(self.params, params_cls):
                self.task_id = task_id
                break
        else:
            raise NotImplementedError(
                f"Unsupported params type: {type(self.params).__name__}"
            )

        # Shared reference to the module-level mapping; mutating it mutates the global.
        self.available_hardware = AVAILABLE_HARDWARE
        # Only "local" and "local-cli" set wait; "local-ui" and remote backends do not.
        self.wait = self.backend in ("local", "local-cli")

        self.env_vars = {
            "HF_TOKEN": self.params.token,
            "AUTOTRAIN_USERNAME": self.username,
            "PROJECT_NAME": self.params.project_name,
            "TASK_ID": str(self.task_id),
            # json.dumps() is applied to the JSON text from model_dump_json(), so the
            # value is a quoted JSON string literal (double-encoded); consumers
            # presumably decode it twice — confirm before changing.
            "PARAMS": json.dumps(self.params.model_dump_json()),
            "DATA_PATH": self.params.data_path,
        }
        # MODEL is not exported for GenericParams (presumably it has no model field).
        if not isinstance(self.params, GenericParams):
            self.env_vars["MODEL"] = self.params.model
|