File size: 1,810 Bytes
33d4721
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import subprocess
import sys
from argparse import ArgumentParser

from autotrain import logger

from . import BaseAutoTrainCommand


def run_app_command_factory(args):
    """Create a ``RunSetupCommand`` from the parsed ``setup`` subcommand arguments.

    Args:
        args: argparse namespace carrying the ``update_torch`` and ``colab``
            flags registered by ``RunSetupCommand.register_subcommand``.
    """
    update_torch, colab = args.update_torch, args.colab
    return RunSetupCommand(update_torch=update_torch, colab=colab)


class RunSetupCommand(BaseAutoTrainCommand):
    """``autotrain setup``: adjust xformers and optionally reinstall PyTorch via pip."""

    # Pinned xformers build installed when running under Google Colab.
    _XFORMERS_COLAB_SPEC = "xformers==0.0.24"
    # Wheel index used for the PyTorch reinstall (CUDA 12.1 builds).
    _TORCH_INDEX_URL = "https://download.pytorch.org/whl/cu121"

    @staticmethod
    def register_subcommand(parser: ArgumentParser):
        """Register the ``setup`` subcommand and its flags.

        Args:
            parser: object providing ``add_parser`` -- in practice the
                subparsers action from ``ArgumentParser.add_subparsers()``
                rather than a plain ``ArgumentParser``.
        """
        run_setup_parser = parser.add_parser(
            "setup",
            description="✨ Run AutoTrain setup",
        )
        run_setup_parser.add_argument(
            "--update-torch",
            action="store_true",
            help="Update PyTorch to latest version",
        )
        run_setup_parser.add_argument(
            "--colab",
            action="store_true",
            help="Run setup for Google Colab",
        )
        run_setup_parser.set_defaults(func=run_app_command_factory)

    def __init__(self, update_torch: bool, colab: bool = False):
        """
        Args:
            update_torch: when True, upgrade torch/torchvision/torchaudio from
                the CUDA 12.1 PyTorch wheel index.
            colab: when True, install the pinned xformers build; otherwise
                xformers is uninstalled.
        """
        self.update_torch = update_torch
        self.colab = colab

    @staticmethod
    def _run_pip(*pip_args: str) -> bool:
        """Run ``<current python> -m pip <pip_args>`` and report whether it succeeded.

        Using ``sys.executable`` guarantees packages are changed in the same
        environment that is running autotrain, not whatever ``pip`` is first
        on PATH. Output is captured; on a non-zero exit status the command and
        pip's stderr are logged and False is returned (no exception raised).
        """
        cmd = [sys.executable, "-m", "pip", *pip_args]
        result = subprocess.run(cmd, capture_output=True, check=False)
        if result.returncode == 0:
            return True
        logger.error(f"Command failed with exit code {result.returncode}: {' '.join(cmd)}")
        stderr = result.stderr.decode(errors="replace").strip()
        if stderr:
            logger.error(stderr)
        return False

    def run(self):
        """Apply the setup steps selected by the constructor flags.

        Each pip step is attempted independently; a failed step is logged as
        an error instead of being reported as a success.
        """
        if self.colab:
            logger.info(f"Installing {self._XFORMERS_COLAB_SPEC}")
            if self._run_pip("install", "-U", self._XFORMERS_COLAB_SPEC):
                logger.info(f"Successfully installed {self._XFORMERS_COLAB_SPEC}")
        else:
            # NOTE(review): the reason xformers is removed outside Colab is not
            # visible from this file -- confirm with the maintainers.
            logger.info("Uninstalling xformers")
            if self._run_pip("uninstall", "-y", "xformers"):
                logger.info("Successfully uninstalled xformers")

        if self.update_torch:
            logger.info("Installing latest PyTorch")
            # -U is required: without it pip leaves an already-installed torch
            # untouched ("Requirement already satisfied"), so nothing updates.
            if self._run_pip(
                "install",
                "-U",
                "torch",
                "torchvision",
                "torchaudio",
                "--index-url",
                self._TORCH_INDEX_URL,
            ):
                logger.info("Successfully installed latest PyTorch")