# Web-page residue kept as a comment (not part of the program):
# uploaded by hardiktiwari -- "Upload 244 files" -- revision 33d4721 (verified)
import argparse
import sys

from autotrain import __version__, logger
from autotrain.cli.run_api import RunAutoTrainAPICommand
from autotrain.cli.run_app import RunAutoTrainAppCommand
from autotrain.cli.run_extractive_qa import RunAutoTrainExtractiveQACommand
from autotrain.cli.run_image_classification import RunAutoTrainImageClassificationCommand
from autotrain.cli.run_image_regression import RunAutoTrainImageRegressionCommand
from autotrain.cli.run_llm import RunAutoTrainLLMCommand
from autotrain.cli.run_object_detection import RunAutoTrainObjectDetectionCommand
from autotrain.cli.run_sent_tranformers import RunAutoTrainSentenceTransformersCommand
from autotrain.cli.run_seq2seq import RunAutoTrainSeq2SeqCommand
from autotrain.cli.run_setup import RunSetupCommand
from autotrain.cli.run_spacerunner import RunAutoTrainSpaceRunnerCommand
from autotrain.cli.run_tabular import RunAutoTrainTabularCommand
from autotrain.cli.run_text_classification import RunAutoTrainTextClassificationCommand
from autotrain.cli.run_text_regression import RunAutoTrainTextRegressionCommand
from autotrain.cli.run_token_classification import RunAutoTrainTokenClassificationCommand
from autotrain.cli.run_tools import RunAutoTrainToolsCommand
from autotrain.parser import AutoTrainConfigParser
def main():
    """Entry point for the ``autotrain`` command-line interface.

    Builds the top-level argument parser, registers every sub-command on it,
    parses the process arguments and then dispatches, in this order of
    precedence:

    1. ``--version`` / ``-v``: print ``__version__`` and exit with status 0.
    2. ``--config <path>``: log the path, run ``AutoTrainConfigParser`` on it
       and exit with status 0.
    3. The parsed namespace has no ``func`` attribute (presumably because no
       sub-command was selected -- each sub-command is expected to set it):
       print the help text and exit with status 1.
    4. Otherwise: build the command with ``args.func(args)`` and call its
       ``run()`` method.

    Raises:
        SystemExit: on the ``--version`` and ``--config`` paths, when no
            ``func`` is present, and whenever argparse itself exits
            (``--help`` or invalid arguments).
    """
    # NOTE: the first positional argument of ArgumentParser is ``prog``, so
    # this string is the program name argparse uses in its error messages.
    parser = argparse.ArgumentParser(
        "AutoTrain advanced CLI",
        usage="autotrain <command> [<args>]",
        epilog="For more information about a command, run: `autotrain <command> --help`",
    )
    parser.add_argument("--version", "-v", help="Display AutoTrain version", action="store_true")
    parser.add_argument("--config", help="Optional configuration file", type=str)
    commands_parser = parser.add_subparsers(help="commands")

    # Register commands. The order is kept fixed; it is presumably the order
    # in which the sub-commands are listed in the help output.
    command_classes = (
        RunAutoTrainAppCommand,
        RunAutoTrainLLMCommand,
        RunSetupCommand,
        RunAutoTrainAPICommand,
        RunAutoTrainTextClassificationCommand,
        RunAutoTrainImageClassificationCommand,
        RunAutoTrainTabularCommand,
        RunAutoTrainSpaceRunnerCommand,
        RunAutoTrainSeq2SeqCommand,
        RunAutoTrainTokenClassificationCommand,
        RunAutoTrainToolsCommand,
        RunAutoTrainTextRegressionCommand,
        RunAutoTrainObjectDetectionCommand,
        RunAutoTrainSentenceTransformersCommand,
        RunAutoTrainImageRegressionCommand,
        RunAutoTrainExtractiveQACommand,
    )
    for command_class in command_classes:
        command_class.register_subcommand(commands_parser)

    args = parser.parse_args()

    # sys.exit() is used instead of the bare exit() builtin: the latter is
    # injected by the ``site`` module for interactive use, is missing under
    # ``python -S`` / some frozen builds, and closes sys.stdin before raising.
    if args.version:
        print(__version__)
        sys.exit(0)

    # A configuration file takes precedence over any sub-command given.
    if args.config:
        logger.info(f"Using AutoTrain configuration: {args.config}")
        cp = AutoTrainConfigParser(args.config)
        cp.run()
        sys.exit(0)

    if not hasattr(args, "func"):
        parser.print_help()
        sys.exit(1)

    command = args.func(args)
    command.run()
# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()