File size: 5,236 Bytes
33d4721
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
from argparse import ArgumentParser

from autotrain import logger
from autotrain.backends.base import AVAILABLE_HARDWARE
from autotrain.backends.spaces import SpaceRunner
from autotrain.trainers.generic.params import GenericParams
from autotrain.trainers.generic.utils import create_dataset_repo

from . import BaseAutoTrainCommand


# Valid ``--backend`` values: only the hardware keys prefixed with "spaces-".
BACKEND_CHOICES = [
    hardware_key
    for hardware_key in AVAILABLE_HARDWARE.keys()
    if hardware_key.startswith("spaces-")
]


def run_spacerunner_command_factory(args):
    """Build a ``RunAutoTrainSpaceRunnerCommand`` from parsed CLI ``args``.

    Installed as the ``func`` default of the ``spacerunner`` sub-parser;
    presumably invoked by the top-level CLI dispatcher (not visible here).
    """
    command = RunAutoTrainSpaceRunnerCommand(args)
    return command


class RunAutoTrainSpaceRunnerCommand(BaseAutoTrainCommand):
    """CLI command ``spacerunner``: run a user script through ``SpaceRunner``.

    ``__init__`` turns the raw ``--env`` and ``--args`` strings into dicts.
    ``run`` calls ``create_dataset_repo`` for the script, builds
    ``GenericParams`` from the CLI options and starts a ``SpaceRunner`` job on
    the selected backend.
    """

    @staticmethod
    def register_subcommand(parser: ArgumentParser):
        """Register the ``spacerunner`` sub-command and its options.

        Args:
            parser: Object providing ``add_parser`` -- in practice the value
                returned by ``ArgumentParser.add_subparsers()``, despite the
                annotation.
        """
        arg_list = [
            {
                "arg": "--project-name",
                "help": "Name of the project. Must be unique.",
                "required": True,
                "type": str,
            },
            {
                "arg": "--script-path",
                "help": "Path to the script",
                "required": True,
                "type": str,
            },
            {
                "arg": "--username",
                "help": "Hugging Face Username, can also be an organization name",
                "required": True,
                "type": str,
            },
            {
                "arg": "--token",
                "help": "Hugging Face API Token",
                "required": True,
                "type": str,
            },
            {
                "arg": "--backend",
                "help": "Hugging Face backend to use",
                "required": True,
                "type": str,
                "choices": BACKEND_CHOICES,
            },
            {
                "arg": "--env",
                # Quoted: an unquoted ';' would end the shell command.
                "help": "Environment variables, e.g. --env 'FOO=bar;FOO2=bar2;FOO3=bar3'",
                "required": False,
                "type": str,
            },
            {
                "arg": "--args",
                "help": "Arguments to pass to the script, e.g. --args 'foo=bar;foo2=bar2;foo3=bar3;store_true_arg'",
                "required": False,
                "type": str,
            },
        ]
        run_spacerunner_parser = parser.add_parser("spacerunner", description="✨ Run AutoTrain SpaceRunner")
        for arg in arg_list:
            names = [arg["arg"]] + arg.get("alias", [])
            kwargs = {
                # "--project-name" -> "project_name"
                "dest": arg["arg"].replace("--", "").replace("-", "_"),
                "help": arg["help"],
                "required": arg.get("required", False),
                "default": arg.get("default"),
                "choices": arg.get("choices"),
            }
            # argparse rejects ``type`` for flag-style actions such as
            # "store_true", so pass either ``action`` or ``type``, never both.
            if "action" in arg:
                kwargs["action"] = arg["action"]
            else:
                kwargs["type"] = arg.get("type")
            run_spacerunner_parser.add_argument(*names, **kwargs)
        run_spacerunner_parser.set_defaults(func=run_spacerunner_command_factory)

    def __init__(self, args):
        """Store parsed CLI ``args`` and normalise ``--env`` / ``--args`` in place.

        After construction ``self.args.env`` and ``self.args.args`` are dicts
        (empty when the option was omitted or empty).

        Raises:
            ValueError: If an ``--env`` or ``--args`` entry is malformed.
        """
        self.args = args
        self.args.env = self._parse_env(self.args.env)
        self.args.args = self._parse_script_args(self.args.args)

    @staticmethod
    def _parse_env(raw):
        """Parse ``"NAME=value;NAME2=value2"`` into ``{"NAME": "value", ...}``.

        Only the first ``=`` separates name from value, so values may contain
        ``=``. Empty segments (e.g. a trailing ``;``) are ignored. Names are
        stripped of surrounding whitespace; values are kept verbatim.

        Raises:
            ValueError: If a segment has no ``=`` or an empty name.
        """
        env_vars = {}
        if not raw:
            return env_vars
        for item in raw.split(";"):
            if not item.strip():
                continue
            name, sep, value = item.partition("=")
            name = name.strip()
            if not sep or not name:
                raise ValueError(f"Invalid environment variable format: {item!r}. Expected NAME=VALUE.")
            env_vars[name] = value
        return env_vars

    @staticmethod
    def _parse_script_args(raw):
        """Parse ``"key=value;flag"`` into the dict passed to the script.

        ``key=value`` entries map key to value (split on the first ``=`` only).
        A bare ``flag`` maps to ``""``; flags are inserted after all key/value
        entries and overwrite a same-named key. How an empty value is rendered
        on the script's command line is decided downstream -- presumably as a
        bare flag; confirm in ``GenericParams`` / ``SpaceRunner``.

        Raises:
            ValueError: If an entry has an empty name (e.g. ``"=value"``).
        """
        app_args = {}
        if not raw:
            return app_args
        flags = []
        for item in raw.split(";"):
            if not item.strip():
                continue
            name, sep, value = item.partition("=")
            name = name.strip()
            if not name:
                raise ValueError(f"Invalid argument format: {item!r}. Expected name=value or a bare flag name.")
            if sep:
                app_args[name] = value
            else:
                flags.append(name)
        for name in flags:
            app_args[name] = ""
        return app_args

    def run(self):
        """Create the dataset repo for the script and launch the SpaceRunner job.

        Logs the job id returned by ``SpaceRunner.create()``.
        """
        dataset_id = create_dataset_repo(
            username=self.args.username,
            project_name=self.args.project_name,
            script_path=self.args.script_path,
            token=self.args.token,
        )
        params = GenericParams(
            project_name=self.args.project_name,
            data_path=dataset_id,
            username=self.args.username,
            token=self.args.token,
            script_path=self.args.script_path,
            env=self.args.env,
            args=self.args.args,
        )
        project = SpaceRunner(params=params, backend=self.args.backend)
        job_id = project.create()
        logger.info(f"Job ID: {job_id}")