Spaces:
Sleeping
Sleeping
File size: 3,999 Bytes
33d4721 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 |
from argparse import ArgumentParser
from autotrain import logger
from autotrain.cli.utils import get_field_info
from autotrain.project import AutoTrainProject
from autotrain.trainers.text_regression.params import TextRegressionParams
from . import BaseAutoTrainCommand
def run_text_regression_command_factory(args):
    """Build the text-regression command object for parsed CLI arguments.

    This callable is installed as the ``func`` default of the
    ``text-regression`` sub-parser, so it receives the parsed namespace.

    Args:
        args: Parsed command-line arguments, passed through unchanged.

    Returns:
        A ``RunAutoTrainTextRegressionCommand`` wrapping ``args``.
    """
    command = RunAutoTrainTextRegressionCommand(args)
    return command
class RunAutoTrainTextRegressionCommand(BaseAutoTrainCommand):
    """CLI command behind the ``text-regression`` sub-command.

    ``register_subcommand`` wires the command's options into the CLI,
    ``__init__`` normalises and validates the parsed arguments, and ``run``
    creates the training project.
    """

    @staticmethod
    def register_subcommand(parser: ArgumentParser):
        """Register the ``text-regression`` sub-command and its options.

        The option list is the four command-level flags (``--train``,
        ``--deploy``, ``--inference``, ``--backend``) followed by whatever
        ``get_field_info(TextRegressionParams)`` returns, minus
        ``--disable-gradient-checkpointing``.

        Args:
            parser: Object providing ``add_parser``. It is annotated as
                ``ArgumentParser``, but presumably it is the sub-parsers
                object returned by ``add_subparsers()`` -- confirm against
                the caller.
        """
        param_specs = get_field_info(TextRegressionParams)
        command_specs = [
            {
                "arg": "--train",
                "help": "Command to train the model",
                "required": False,
                "action": "store_true",
            },
            {
                "arg": "--deploy",
                "help": "Command to deploy the model (limited availability)",
                "required": False,
                "action": "store_true",
            },
            {
                "arg": "--inference",
                "help": "Command to run inference (limited availability)",
                "required": False,
                "action": "store_true",
            },
            {
                "arg": "--backend",
                "help": "Backend",
                "required": False,
                "type": str,
                "default": "local",
            },
        ]
        option_specs = [
            spec
            for spec in command_specs + param_specs
            if spec["arg"] != "--disable-gradient-checkpointing"
        ]

        subparser = parser.add_parser(
            "text-regression", description="✨ Run AutoTrain Text Regression"
        )

        for spec in option_specs:
            flag = spec["arg"]
            option_names = [flag] + spec.get("alias", [])
            # Keyword arguments shared by every option, whatever its kind.
            option_kwargs = {
                "dest": flag.replace("--", "").replace("-", "_"),
                "help": spec["help"],
                "required": spec.get("required", False),
                "default": spec.get("default"),
            }
            if "action" in spec:
                # Action-style options (e.g. store_true) take neither a type
                # nor choices.
                option_kwargs["action"] = spec.get("action")
            else:
                option_kwargs["type"] = spec.get("type")
                option_kwargs["choices"] = spec.get("choices")
            subparser.add_argument(*option_names, **option_kwargs)

        subparser.set_defaults(func=run_text_regression_command_factory)

    def __init__(self, args):
        """Store the parsed arguments, normalise flags and validate them.

        Args:
            args: Parsed argument namespace; it is mutated in place.

        Raises:
            ValueError: If ``train`` is not set, or if it is set and
                ``project_name``, ``data_path`` or ``model`` is ``None``, or
                ``push_to_hub`` is set while ``username`` is ``None``.
        """
        self.args = args

        # Options are registered with ``default=spec.get("default")``, so a
        # flag whose spec carries no "default" key parses to None when absent
        # from the command line. Coerce those to a real False here.
        boolean_flags = (
            "train",
            "deploy",
            "inference",
            "auto_find_batch_size",
            "push_to_hub",
        )
        for flag_name in boolean_flags:
            if getattr(self.args, flag_name) is None:
                setattr(self.args, flag_name, False)

        # NOTE(review): only --train gets past this guard; passing --deploy
        # or --inference alone still raises the error below.
        if not self.args.train:
            raise ValueError("Must specify --train, --deploy or --inference")

        required_for_training = (
            ("project_name", "Project name must be specified"),
            ("data_path", "Data path must be specified"),
            ("model", "Model must be specified"),
        )
        for attr_name, message in required_for_training:
            if getattr(self.args, attr_name) is None:
                raise ValueError(message)

        if self.args.push_to_hub and self.args.username is None:
            raise ValueError("Username must be specified for push to hub")

    def run(self):
        """Create the training project and log the resulting job id.

        Does nothing beyond the initial log line unless ``args.train`` is set.
        """
        logger.info("Running Text Regression")
        if not self.args.train:
            return

        # Every parsed argument is forwarded as a keyword argument.
        training_params = TextRegressionParams(**vars(self.args))
        project = AutoTrainProject(
            params=training_params,
            backend=self.args.backend,
            process=True,
        )
        job_id = project.create()
        logger.info(f"Job ID: {job_id}")
|