# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

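"""Check that the Cosmos (cosmos-transfer1) environment is set up correctly.

The script attempts to import a set of critical packages and reports which ones
are missing; pass --training to also check training-only dependencies.
setup_environment() is a helper that installs the required system and Python
packages via os.system and is not invoked by the check itself.
"""
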
import argparse
import importlib
import os
import sys

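# Run pip through the current interpreter so installs target this environment;
# the command is printed once so the active interpreter is visible in the logs.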
PIP = f"{sys.executable} -m pip"
print(PIP)

def setup_environment():
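    """Install system and Python dependencies for Cosmos via os.system calls."""
    # System dependency: ImageMagick development headers (libmagickwand-dev),
    # used by image-processing bindings such as Wand.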
    os.system("apt-get update && apt-get install -qqy libmagickwand-dev")

    # install packages
    # os.system(
    #     f'export FLASH_ATTENTION_SKIP_CUDA_BUILD=FALSE && \
    #     {PIP} install --timeout=1000000000 --no-build-isolation "flash-attn<=2.7.4.post1"'
    # )
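    # FlashInfer attention kernels: prebuilt wheel for CUDA 12.8 / PyTorch 2.7.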
    os.system(
        f"{PIP} install --timeout=1000000000 \
        https://download.pytorch.org/whl/cu128/flashinfer/flashinfer_python-0.2.5%2Bcu128torch2.7-cp38-abi3-linux_x86_64.whl"
    )
    os.system(f'export VLLM_ATTENTION_BACKEND=FLASHINFER && {PIP} install "vllm==0.9.0"')
    os.system(f'{PIP} install "decord==0.6.0"')

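    # Link the headers shipped inside the pip-installed nvidia-* packages into the
    # CUDA include directories so that source builds such as transformer-engine
    # (installed below with --no-build-isolation) can find them.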
    os.system(
        "export CONDA_PREFIX=/usr/local/cuda && \
        ln -sf $CONDA_PREFIX/lib/python3.10/site-packages/nvidia/*/include/* $CONDA_PREFIX/include/"
    )
    os.system(
        "export CONDA_PREFIX=/usr/local/cuda && \
        ln -sf $CONDA_PREFIX/lib/python3.10/site-packages/nvidia/*/include/* $CONDA_PREFIX/include/python3.10"
    )

    os.system(f'{PIP} install --timeout=1000000000 --no-build-isolation "transformer-engine[pytorch]"')
    os.system(f'{PIP} install --timeout=1000000000 "decord==0.6.0"')

    # os.system(
    #     f'{PIP} install --timeout=1000000000 \
    #     "git+https://github.com/nvidia-cosmos/cosmos-transfer1@e4055e39ee9c53165e85275bdab84ed20909714a"'
    # )


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--training",
        action="store_true",
        help="Whether to check training-specific dependencies",
    )
    return parser.parse_args()


def check_packages(package_list):
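    """Import each package in package_list and report the result; return True only if all imports succeed."""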
    all_success = True
    for package in package_list:
        try:
            _ = importlib.import_module(package)
        except Exception:
            print(f"\033[91m[ERROR]\033[0m Package not successfully imported: \033[93m{package}\033[0m")
            all_success = False
        else:
            print(f"\033[92m[SUCCESS]\033[0m {package} found")

    return all_success


def main():
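    """Run the environment checks and return True if every required package imports."""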
    args = parse_args()

    if not (sys.version_info.major == 3 and sys.version_info.minor >= 10):
        detected = f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}"
        print(f"\033[91m[ERROR]\033[0m Python 3.10+ is required. You have: \033[93m{detected}\033[0m")
        sys.exit(1)

    if "CONDA_PREFIX" not in os.environ:
        print(
            "\033[93m[WARNING]\033[0m CONDA_PREFIX is not set. "
            "For manual installations, Cosmos should be run from the cosmos-transfer1 conda environment (see INSTALL.md). "
            "This warning can be ignored when running inside the container."
        )

    print("Attempting to import critical packages...")

    packages = ["torch", "torchvision", "transformers", "megatron.core", "transformer_engine", "vllm", "pandas"]
    packages_training = [
        "apex.multi_tensor_apply",
    ]

    all_success = check_packages(packages)
    if args.training:
        if not check_packages(packages_training):
            all_success = False

    if all_success:
        print("-----------------------------------------------------------")
        print("\033[92m[SUCCESS]\033[0m Cosmos environment setup is successful!")

    return all_success


if __name__ == "__main__":
    print(f"Enivornment check success ? {main()}")