# Uploaded by hardiktiwari — "Upload 244 files" (commit 33d4721, verified)
import json
import os
import subprocess
from autotrain.commands import launch_command
from autotrain.trainers.clm.params import LLMTrainingParams
from autotrain.trainers.extractive_question_answering.params import ExtractiveQuestionAnsweringParams
from autotrain.trainers.generic.params import GenericParams
from autotrain.trainers.image_classification.params import ImageClassificationParams
from autotrain.trainers.image_regression.params import ImageRegressionParams
from autotrain.trainers.object_detection.params import ObjectDetectionParams
from autotrain.trainers.sent_transformers.params import SentenceTransformersParams
from autotrain.trainers.seq2seq.params import Seq2SeqParams
from autotrain.trainers.tabular.params import TabularParams
from autotrain.trainers.text_classification.params import TextClassificationParams
from autotrain.trainers.text_regression.params import TextRegressionParams
from autotrain.trainers.token_classification.params import TokenClassificationParams
from autotrain.trainers.vlm.params import VLMTrainingParams
# True unless the ALLOW_REMOTE_CODE environment variable is set to something
# other than "true" (compared case-insensitively); an unset variable counts as "true".
ALLOW_REMOTE_CODE = os.getenv("ALLOW_REMOTE_CODE", "true").lower() == "true"
def run_training(params, task_id, local=False, wait=False):
    """
    Run the training process based on the provided parameters and task ID.

    Args:
        params (str): JSON string of the parameters required for training. If the
            first decode yields another string (a doubly-encoded payload), that
            string is decoded once more.
        task_id (int): Identifier for the type of task to be performed. It selects
            which params class validates the decoded parameters.
        local (bool, optional): Flag to indicate if the training should be run locally.
            Not read by this function; kept for interface compatibility. Defaults to False.
        wait (bool, optional): Flag to indicate if the function should wait for the
            process to complete. Defaults to False.

    Returns:
        int: Process ID of the launched training process.

    Raises:
        NotImplementedError: If the task_id does not match any of the predefined tasks.
    """
    # Task id -> params class used to build the training config.
    # Several ids share one class (1/2 text classification, 13-16/26 tabular).
    params_cls_by_task = {
        9: LLMTrainingParams,
        28: Seq2SeqParams,
        1: TextClassificationParams,
        2: TextClassificationParams,
        13: TabularParams,
        14: TabularParams,
        15: TabularParams,
        16: TabularParams,
        26: TabularParams,
        27: GenericParams,
        18: ImageClassificationParams,
        4: TokenClassificationParams,
        10: TextRegressionParams,
        29: ObjectDetectionParams,
        30: SentenceTransformersParams,
        24: ImageRegressionParams,
        31: VLMTrainingParams,
        5: ExtractiveQuestionAnsweringParams,
    }

    params = json.loads(params)
    # Handle a doubly-encoded payload: the first decode produced a JSON string.
    if isinstance(params, str):
        params = json.loads(params)

    try:
        params_cls = params_cls_by_task[task_id]
    except (KeyError, TypeError):
        # TypeError covers unhashable ids, which the original equality checks
        # also rejected via NotImplementedError.
        raise NotImplementedError(f"Unsupported task_id: {task_id!r}") from None
    params = params_cls(**params)

    # Write the resolved config into the project directory before launching.
    params.save(output_dir=params.project_name)

    # Popen needs string arguments; coerce whatever launch_command returned.
    cmd = [str(c) for c in launch_command(params=params)]
    process = subprocess.Popen(cmd, env=os.environ.copy())
    if wait:
        process.wait()
    return process.pid