Spaces:
Sleeping
Sleeping
from argparse import ArgumentParser | |
from autotrain import logger | |
from autotrain.cli.utils import get_field_info | |
from autotrain.project import AutoTrainProject | |
from autotrain.trainers.extractive_question_answering.params import ExtractiveQuestionAnsweringParams | |
from . import BaseAutoTrainCommand | |
def run_extractive_qa_command_factory(args):
    """Build the extractive-QA command object from parsed CLI arguments.

    This callable is attached as the ``func`` default of the ``extractive-qa``
    subparser in ``RunAutoTrainExtractiveQACommand.register_subcommand``.

    Args:
        args: the parsed argparse namespace for the ``extractive-qa`` subcommand.

    Returns:
        A ``RunAutoTrainExtractiveQACommand`` wrapping *args*.
    """
    command = RunAutoTrainExtractiveQACommand(args)
    return command
class RunAutoTrainExtractiveQACommand(BaseAutoTrainCommand):
    """CLI command for AutoTrain extractive question answering.

    ``register_subcommand`` adds the ``extractive-qa`` subparser, ``__init__``
    normalizes and validates the parsed arguments, and ``run`` launches a
    training job when ``--train`` was given.
    """

    @staticmethod
    def register_subcommand(parser: ArgumentParser):
        """Register the ``extractive-qa`` subcommand and its options.

        Options come from two sources: a fixed set of command flags
        (``--train``, ``--deploy``, ``--inference``, ``--backend``), followed by
        one option per field reported by
        ``get_field_info(ExtractiveQuestionAnsweringParams)``.

        Args:
            parser: despite the annotation, this must be the object returned by
                ``ArgumentParser.add_subparsers()``, because ``add_parser`` is
                called on it.
        """
        arg_list = get_field_info(ExtractiveQuestionAnsweringParams)
        arg_list = [
            {
                "arg": "--train",
                "help": "Command to train the model",
                "required": False,
                "action": "store_true",
            },
            {
                "arg": "--deploy",
                "help": "Command to deploy the model (limited availability)",
                "required": False,
                "action": "store_true",
            },
            {
                "arg": "--inference",
                "help": "Command to run inference (limited availability)",
                "required": False,
                "action": "store_true",
            },
            {
                "arg": "--backend",
                "help": "Backend to use for training",
                "required": False,
                "default": "local",
            },
        ] + arg_list
        # This option is filtered out, so the extractive-qa subcommand does not accept it.
        arg_list = [arg for arg in arg_list if arg["arg"] != "--disable-gradient-checkpointing"]

        run_extractive_qa_parser = parser.add_parser(
            "extractive-qa", description="✨ Run AutoTrain Extractive Question Answering"
        )
        for arg in arg_list:
            names = [arg["arg"]] + arg.get("alias", [])
            kwargs = {
                # "--model-name" -> "model_name"
                "dest": arg["arg"].replace("--", "").replace("-", "_"),
                "help": arg["help"],
                "required": arg.get("required", False),
                "default": arg.get("default"),
            }
            if "action" in arg:
                # Flag actions such as store_true do not accept type/choices
                # keyword arguments, so only pass the action.
                kwargs["action"] = arg["action"]
            else:
                kwargs["type"] = arg.get("type")
                kwargs["choices"] = arg.get("choices")
            run_extractive_qa_parser.add_argument(*names, **kwargs)
        run_extractive_qa_parser.set_defaults(func=run_extractive_qa_command_factory)

    def __init__(self, args):
        """Store, normalize and validate the parsed CLI arguments.

        Args:
            args: the parsed argparse namespace for this subcommand.

        Raises:
            ValueError: if ``--train`` is not given (``--deploy`` and
                ``--inference`` are rejected), if ``project_name``,
                ``data_path`` or ``model`` is missing, or if ``push_to_hub``
                is set without ``username``.
        """
        self.args = args

        # register_subcommand passes default=arg.get("default"), which is None
        # for the command flags, so an absent flag can come back as None.
        # Coerce these boolean options to False.
        store_true_arg_names = [
            "train",
            "deploy",
            "inference",
            "auto_find_batch_size",
            "push_to_hub",
        ]
        for arg_name in store_true_arg_names:
            if getattr(self.args, arg_name) is None:
                setattr(self.args, arg_name, False)

        if not self.args.train:
            # Only training is implemented by this command. Give an accurate
            # message instead of claiming no mode flag was passed.
            if self.args.deploy or self.args.inference:
                raise ValueError("--deploy and --inference are not supported by the extractive-qa command; use --train")
            raise ValueError("Must specify --train, --deploy or --inference")

        if self.args.project_name is None:
            raise ValueError("Project name must be specified")
        if self.args.data_path is None:
            raise ValueError("Data path must be specified")
        if self.args.model is None:
            raise ValueError("Model must be specified")
        if self.args.push_to_hub and self.args.username is None:
            raise ValueError("Username must be specified for push to hub")

    def run(self):
        """Create an AutoTrain project from the arguments and log its job ID.

        Does nothing beyond logging when ``--train`` is not set, although
        ``__init__`` already rejects that case.
        """
        logger.info("Running Extractive Question Answering")
        if self.args.train:
            # NOTE(review): vars(self.args) also forwards non-param keys such as
            # train, deploy, backend and func (set via set_defaults). Presumably
            # ExtractiveQuestionAnsweringParams ignores unknown keys — confirm.
            params = ExtractiveQuestionAnsweringParams(**vars(self.args))
            project = AutoTrainProject(params=params, backend=self.args.backend, process=True)
            job_id = project.create()
            logger.info(f"Job ID: {job_id}")