Spaces:
Sleeping
Sleeping
File size: 1,810 Bytes
33d4721 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 |
import subprocess
import sys
from argparse import ArgumentParser

from autotrain import logger

from . import BaseAutoTrainCommand
def run_app_command_factory(args):
    """Create the ``setup`` command object from parsed CLI arguments.

    Args:
        args: parsed argument namespace; must provide the ``update_torch``
            and ``colab`` attributes registered by the ``setup`` subparser.

    Returns:
        A ``RunSetupCommand`` configured from those two flags.
    """
    return RunSetupCommand(update_torch=args.update_torch, colab=args.colab)
class RunSetupCommand(BaseAutoTrainCommand):
    """``autotrain setup``: prepare xformers and optionally upgrade PyTorch.

    With ``--colab`` a pinned xformers release is installed; otherwise
    xformers is uninstalled. With ``--update-torch`` torch, torchvision and
    torchaudio are additionally upgraded from the PyTorch cu121 wheel index.
    """

    # Pinned xformers requirement installed by the Colab setup path.
    XFORMERS_COLAB_SPEC = "xformers==0.0.24"
    # PyTorch wheel index used by --update-torch (cu121 = CUDA 12.1 builds).
    TORCH_INDEX_URL = "https://download.pytorch.org/whl/cu121"

    @staticmethod
    def register_subcommand(parser: ArgumentParser):
        """Add the ``setup`` subcommand and its flags to *parser*.

        Args:
            parser: object exposing ``add_parser`` -- presumably the action
                returned by ``ArgumentParser.add_subparsers()``; TODO confirm
                against the caller.
        """
        run_setup_parser = parser.add_parser(
            "setup",
            description="✨ Run AutoTrain setup",
        )
        run_setup_parser.add_argument(
            "--update-torch",
            action="store_true",
            help="Update PyTorch to latest version",
        )
        run_setup_parser.add_argument(
            "--colab",
            action="store_true",
            help="Run setup for Google Colab",
        )
        # Store the factory as ``func`` on the parsed namespace; presumably the
        # CLI entry point calls ``args.func(args)`` to build this command.
        run_setup_parser.set_defaults(func=run_app_command_factory)

    def __init__(self, update_torch: bool, colab: bool = False):
        """Store the setup options.

        Args:
            update_torch: also upgrade torch, torchvision and torchaudio.
            colab: install the pinned Colab xformers build instead of
                uninstalling xformers.
        """
        self.update_torch = update_torch
        self.colab = colab

    def run(self):
        """Run the selected setup steps.

        Each pip invocation is checked: success is only logged when pip exits
        with status 0; otherwise pip's stderr is logged as an error and the
        remaining steps still run (best-effort, nothing is raised).
        """
        if self.colab:
            self._run_pip(
                ["install", "-U", self.XFORMERS_COLAB_SPEC],
                f"Installing {self.XFORMERS_COLAB_SPEC}",
            )
        else:
            self._run_pip(["uninstall", "-y", "xformers"], "Uninstalling xformers")

        if self.update_torch:
            # -U is required: without it pip leaves an already-installed torch
            # untouched, making the "update to latest version" flag a no-op.
            self._run_pip(
                [
                    "install",
                    "-U",
                    "torch",
                    "torchvision",
                    "torchaudio",
                    "--index-url",
                    self.TORCH_INDEX_URL,
                ],
                "Installing latest PyTorch",
            )

    @staticmethod
    def _run_pip(pip_args, action):
        """Run pip with *pip_args* for the current interpreter and log the outcome.

        Args:
            pip_args: list of arguments placed after ``pip`` on the command line.
            action: short human-readable description used in log messages.

        Returns:
            True if pip exited with status 0, False otherwise.
        """
        # ``python -m pip`` targets the interpreter running AutoTrain; a bare
        # ``pip`` on PATH may belong to a different environment. Fall back to
        # it only when the interpreter path is unknown.
        if sys.executable:
            cmd = [sys.executable, "-m", "pip", *pip_args]
        else:
            cmd = ["pip", *pip_args]

        logger.info(f"{action}...")
        result = subprocess.run(
            cmd,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            text=True,
            errors="replace",  # never fail on undecodable pip output
            check=False,
        )
        if result.returncode != 0:
            logger.error(f"{action} failed (pip exit code {result.returncode}): {result.stderr.strip()}")
            return False
        logger.info(f"{action}: done")
        return True
|