File size: 3,345 Bytes
33d4721
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import argparse
import sys

from autotrain import __version__, logger
from autotrain.cli.run_api import RunAutoTrainAPICommand
from autotrain.cli.run_app import RunAutoTrainAppCommand
from autotrain.cli.run_extractive_qa import RunAutoTrainExtractiveQACommand
from autotrain.cli.run_image_classification import RunAutoTrainImageClassificationCommand
from autotrain.cli.run_image_regression import RunAutoTrainImageRegressionCommand
from autotrain.cli.run_llm import RunAutoTrainLLMCommand
from autotrain.cli.run_object_detection import RunAutoTrainObjectDetectionCommand
from autotrain.cli.run_sent_tranformers import RunAutoTrainSentenceTransformersCommand
from autotrain.cli.run_seq2seq import RunAutoTrainSeq2SeqCommand
from autotrain.cli.run_setup import RunSetupCommand
from autotrain.cli.run_spacerunner import RunAutoTrainSpaceRunnerCommand
from autotrain.cli.run_tabular import RunAutoTrainTabularCommand
from autotrain.cli.run_text_classification import RunAutoTrainTextClassificationCommand
from autotrain.cli.run_text_regression import RunAutoTrainTextRegressionCommand
from autotrain.cli.run_token_classification import RunAutoTrainTokenClassificationCommand
from autotrain.cli.run_tools import RunAutoTrainToolsCommand
from autotrain.parser import AutoTrainConfigParser


def main():
    """Run the ``autotrain`` command-line interface.

    Builds the top-level argument parser, registers every sub-command and
    then dispatches on the parsed arguments:

    * ``--version`` / ``-v``: print the AutoTrain version and exit with 0.
    * ``--config FILE``: hand the file to ``AutoTrainConfigParser``, run it
      and exit with 0.
    * ``<command> [<args>]``: build the selected command object via
      ``args.func(args)`` and call its ``run()`` method.
    * no command at all: print the help text and exit with status 1.
    """
    # NOTE(review): the first positional argument of ArgumentParser is
    # ``prog`` (not ``description``). Because ``usage`` is overridden, this
    # string only shows up as the prefix of argparse error messages.
    parser = argparse.ArgumentParser(
        "AutoTrain advanced CLI",
        usage="autotrain <command> [<args>]",
        epilog="For more information about a command, run: `autotrain <command> --help`",
    )
    parser.add_argument("--version", "-v", help="Display AutoTrain version", action="store_true")
    parser.add_argument("--config", help="Optional configuration file", type=str)
    commands_parser = parser.add_subparsers(help="commands")

    # Register commands. argparse lists sub-commands in registration order,
    # so this tuple's order is the order shown in ``--help``.
    command_classes = (
        RunAutoTrainAppCommand,
        RunAutoTrainLLMCommand,
        RunSetupCommand,
        RunAutoTrainAPICommand,
        RunAutoTrainTextClassificationCommand,
        RunAutoTrainImageClassificationCommand,
        RunAutoTrainTabularCommand,
        RunAutoTrainSpaceRunnerCommand,
        RunAutoTrainSeq2SeqCommand,
        RunAutoTrainTokenClassificationCommand,
        RunAutoTrainToolsCommand,
        RunAutoTrainTextRegressionCommand,
        RunAutoTrainObjectDetectionCommand,
        RunAutoTrainSentenceTransformersCommand,
        RunAutoTrainImageRegressionCommand,
        RunAutoTrainExtractiveQACommand,
    )
    for command_class in command_classes:
        command_class.register_subcommand(commands_parser)

    args = parser.parse_args()

    # Use sys.exit rather than the bare ``exit`` builtin: ``exit`` is injected
    # by the ``site`` module for interactive use and is missing under
    # ``python -S`` and in some embedded/frozen interpreters.
    if args.version:
        print(__version__)
        sys.exit(0)

    if args.config:
        logger.info(f"Using AutoTrain configuration: {args.config}")
        cp = AutoTrainConfigParser(args.config)
        cp.run()
        sys.exit(0)

    # ``func`` is only present when a sub-command was chosen; presumably each
    # command's register_subcommand sets it as a parser default (not visible
    # here). Without it there is nothing to run.
    if not hasattr(args, "func"):
        parser.print_help()
        sys.exit(1)

    command = args.func(args)
    command.run()


# Script entry point: run the CLI only when this module is executed directly
# (e.g. ``python -m``), never as a side effect of importing it.
if __name__ == "__main__":
    main()