Spaces:
Sleeping
Sleeping
File size: 5,236 Bytes
33d4721 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 |
from argparse import ArgumentParser
from autotrain import logger
from autotrain.backends.base import AVAILABLE_HARDWARE
from autotrain.backends.spaces import SpaceRunner
from autotrain.trainers.generic.params import GenericParams
from autotrain.trainers.generic.utils import create_dataset_repo
from . import BaseAutoTrainCommand
# Only hardware identifiers whose name starts with "spaces-" are offered as
# values for the --backend option of this command.
BACKEND_CHOICES = [
    hardware_name
    for hardware_name in AVAILABLE_HARDWARE.keys()
    if hardware_name.startswith("spaces-")
]
def run_spacerunner_command_factory(args):
    """Build the SpaceRunner command for an already-parsed argument namespace.

    Registered as the ``func`` default of the ``spacerunner`` sub-parser.

    Args:
        args: Parsed command-line arguments for the ``spacerunner`` command.

    Returns:
        A ``RunAutoTrainSpaceRunnerCommand`` wrapping ``args``.
    """
    command = RunAutoTrainSpaceRunnerCommand(args)
    return command
class RunAutoTrainSpaceRunnerCommand(BaseAutoTrainCommand):
    """``spacerunner`` CLI command: run a user script through ``SpaceRunner``.

    The constructor normalises the ``--env`` and ``--args`` strings of the
    parsed namespace into dictionaries; ``run`` then creates a dataset repo
    for the script, builds ``GenericParams`` and creates the SpaceRunner job.
    """

    @staticmethod
    def register_subcommand(parser: ArgumentParser):
        """Register the ``spacerunner`` sub-command and all of its options.

        Args:
            parser: Object whose ``add_parser`` method is used to create the
                ``spacerunner`` sub-parser.
        """
        arg_list = [
            {
                "arg": "--project-name",
                "help": "Name of the project. Must be unique.",
                "required": True,
                "type": str,
            },
            {
                "arg": "--script-path",
                "help": "Path to the script",
                "required": True,
                "type": str,
            },
            {
                "arg": "--username",
                "help": "Hugging Face Username, can also be an organization name",
                "required": True,
                "type": str,
            },
            {
                "arg": "--token",
                "help": "Hugging Face API Token",
                "required": True,
                "type": str,
            },
            {
                "arg": "--backend",
                "help": "Hugging Face backend to use",
                "required": True,
                "type": str,
                "choices": BACKEND_CHOICES,
            },
            {
                "arg": "--env",
                "help": "Environment variables, e.g. --env FOO=bar;FOO2=bar2;FOO3=bar3",
                "required": False,
                "type": str,
            },
            {
                "arg": "--args",
                "help": "Arguments to pass to the script, e.g. --args foo=bar;foo2=bar2;foo3=bar3;store_true_arg",
                "required": False,
                "type": str,
            },
        ]
        run_spacerunner_parser = parser.add_parser("spacerunner", description="✨ Run AutoTrain SpaceRunner")
        for arg in arg_list:
            names = [arg["arg"]] + arg.get("alias", [])
            # Keyword arguments shared by every option; "--project-name"
            # becomes the namespace attribute "project_name".
            add_argument_kwargs = {
                "dest": arg["arg"].replace("--", "").replace("-", "_"),
                "help": arg["help"],
                "required": arg.get("required", False),
                "default": arg.get("default"),
                "choices": arg.get("choices"),
            }
            # An option carries either an explicit action or a value type,
            # never both.
            if "action" in arg:
                add_argument_kwargs["action"] = arg["action"]
            else:
                add_argument_kwargs["type"] = arg.get("type")
            run_spacerunner_parser.add_argument(*names, **add_argument_kwargs)
        run_spacerunner_parser.set_defaults(func=run_spacerunner_command_factory)

    def __init__(self, args):
        """Store ``args`` and convert its ``env`` / ``args`` strings to dicts.

        Args:
            args: Parsed argument namespace. ``args.env`` and ``args.args``
                are replaced in place by dictionaries (empty when the option
                was not given).

        Raises:
            ValueError: If a non-empty ``--env`` item contains no ``=``.
        """
        self.args = args
        self.args.env = self._parse_env(self.args.env)
        self.args.args = self._parse_script_args(self.args.args)

    @staticmethod
    def _parse_env(raw):
        """Parse ``NAME=value;NAME2=value2`` into a dict of environment variables.

        Only the first ``=`` of an item separates name from value, so values
        may themselves contain ``=``. Empty items (e.g. from a trailing
        ``;``) are ignored.

        Raises:
            ValueError: If a non-empty item contains no ``=``.
        """
        env_vars = {}
        if not raw:
            return env_vars
        for item in raw.split(";"):
            if not item:
                continue
            name, separator, value = item.partition("=")
            if not separator:
                raise ValueError("Invalid environment variable format.")
            env_vars[name] = value
        return env_vars

    @staticmethod
    def _parse_script_args(raw):
        """Parse ``foo=bar;flag`` into a dict of script arguments.

        Items with ``=`` become ``name -> value`` (split on the first ``=``
        only); items without ``=`` are flags and are stored with an empty
        string value. Empty items (e.g. from a trailing ``;``) are ignored.
        """
        app_args = {}
        if not raw:
            return app_args
        flag_names = []
        for item in raw.split(";"):
            if not item:
                continue
            name, separator, value = item.partition("=")
            if separator:
                app_args[name] = value
            else:
                flag_names.append(name)
        # Flags are written last so they come after the key=value entries and
        # take precedence over a key=value entry with the same name.
        for flag_name in flag_names:
            app_args[flag_name] = ""
        return app_args

    def run(self):
        """Create the dataset repo, build the params and create the job.

        The identifier returned by ``create_dataset_repo`` is passed to
        ``GenericParams`` as ``data_path``; the job id returned by
        ``SpaceRunner.create`` is logged.
        """
        dataset_id = create_dataset_repo(
            username=self.args.username,
            project_name=self.args.project_name,
            script_path=self.args.script_path,
            token=self.args.token,
        )
        params = GenericParams(
            project_name=self.args.project_name,
            data_path=dataset_id,
            username=self.args.username,
            token=self.args.token,
            script_path=self.args.script_path,
            env=self.args.env,
            args=self.args.args,
        )
        project = SpaceRunner(params=params, backend=self.args.backend)
        job_id = project.create()
        logger.info(f"Job ID: {job_id}")
|