File size: 4,110 Bytes
6d7fc1c 32a4d55 6d7fc1c e22a639 226c7c9 32a4d55 226c7c9 e22a639 32a4d55 226c7c9 e22a639 32a4d55 e22a639 226c7c9 e22a639 226c7c9 e22a639 32a4d55 e22a639 226c7c9 32a4d55 226c7c9 e22a639 6d7fc1c 226c7c9 6d7fc1c 226c7c9 6d7fc1c 226c7c9 6d7fc1c 226c7c9 6d7fc1c 226c7c9 6d7fc1c 226c7c9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 |
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import importlib
import os
import shlex
import sys
# Shell prefix for running pip under the interpreter executing this script, so
# packages are installed into that same environment. The interpreter path is
# shell-quoted because PIP is interpolated into os.system() command strings,
# where an unquoted path containing spaces or shell metacharacters would split
# into several words and break every install command.
PIP = f"{shlex.quote(sys.executable)} -m pip"
# Echo the resolved pip command (this runs at import time).
print(PIP)
def setup_environment():
    """Install the system and Python dependencies used by Cosmos.

    Runs a fixed, ordered sequence of shell commands through ``os.system``:
    an apt-get install of ``libmagickwand-dev``, then pip installs (via
    ``PIP``, i.e. into the interpreter running this script) of flashinfer,
    vLLM, decord and transformer-engine, plus symlinks of NVIDIA header
    directories into ``$CONDA_PREFIX/include``.

    Every step is attempted even if an earlier one fails (best-effort), but
    failures are no longer silent: each command whose ``os.system`` status is
    non-zero is reported on stdout.

    Returns:
        bool: True if every command returned status 0, False otherwise.
    """
    commands = [
        "apt-get update && apt-get install -qqy libmagickwand-dev",
        # Disabled alternative: build flash-attn from source with
        #   export FLASH_ATTENTION_SKIP_CUDA_BUILD=FALSE &&
        #   {PIP} install --timeout=1000000000 --no-build-isolation "flash-attn<=2.7.4.post1"
        f"{PIP} install --timeout=1000000000 "
        "https://download.pytorch.org/whl/cu128/flashinfer/"
        "flashinfer_python-0.2.5%2Bcu128torch2.7-cp38-abi3-linux_x86_64.whl",
        # NOTE(review): this export only exists for this single shell
        # invocation, so it cannot influence later vLLM runs — confirm intent.
        f'export VLLM_ATTENTION_BACKEND=FLASHINFER && {PIP} install "vllm==0.9.0"',
        # decord is pinned and installed exactly once.
        f'{PIP} install "decord==0.6.0"',
        # Symlink headers shipped inside the nvidia/* pip packages into the
        # include directories.
        # NOTE(review): python3.10 is hard-coded in these paths — confirm it
        # matches the interpreter inside the target container.
        "export CONDA_PREFIX=/usr/local/cuda && "
        "ln -sf $CONDA_PREFIX/lib/python3.10/site-packages/nvidia/*/include/* $CONDA_PREFIX/include/",
        "export CONDA_PREFIX=/usr/local/cuda && "
        "ln -sf $CONDA_PREFIX/lib/python3.10/site-packages/nvidia/*/include/* $CONDA_PREFIX/include/python3.10",
        f'{PIP} install --timeout=1000000000 --no-build-isolation "transformer-engine[pytorch]"',
        # Disabled: install cosmos-transfer1 from git with
        #   {PIP} install --timeout=1000000000
        #   "git+https://github.com/nvidia-cosmos/cosmos-transfer1@e4055e39ee9c53165e85275bdab84ed20909714a"
    ]

    failed = []
    for cmd in commands:
        # os.system returns the shell's wait status; 0 means success.
        status = os.system(cmd)
        if status != 0:
            print(f"\033[91m[ERROR]\033[0m Setup step failed (os.system status {status}): \033[93m{cmd}\033[0m")
            failed.append(cmd)
    return not failed
def parse_args():
    """Parse this script's command-line options from ``sys.argv``.

    Returns:
        argparse.Namespace: has a boolean ``training`` attribute, True only
        when ``--training`` was passed.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--training", action="store_true", help="Whether to check training-specific dependencies")
    namespace = cli.parse_args()
    return namespace
def check_packages(package_list):
    """Try to import each named module and print a per-module result.

    Every entry is attempted even after a failure, so one run lists all
    problems. Any ``Exception`` raised during import counts as a failure,
    not just ImportError; errors raised while a module initialises are
    treated the same way. The exception type and message are printed so the
    cause is visible instead of being discarded.

    Args:
        package_list: iterable of importable module names; dotted names such
            as ``"megatron.core"`` are accepted.

    Returns:
        bool: True if every module imported (True for an empty list),
        False otherwise.
    """
    all_success = True
    for package in package_list:
        try:
            importlib.import_module(package)
        except Exception as err:  # broad on purpose: report and keep checking
            print(f"\033[91m[ERROR]\033[0m Package not successfully imported: \033[93m{package}\033[0m")
            print(f"        reason: {type(err).__name__}: {err}")
            all_success = False
        else:
            print(f"\033[92m[SUCCESS]\033[0m {package} found")
    return all_success
def main():
    """Verify the interpreter version and that Cosmos' key packages import.

    Exits the process with status 1 unless running on Python 3.x with minor
    version 10 or newer. Prints a warning, without failing, when CONDA_PREFIX
    is not set. Imports the core package list, plus the training-only list
    when ``--training`` is given, and prints a success banner if all imports
    worked.

    Returns:
        bool: True if every checked package imported, False otherwise.
    """
    opts = parse_args()

    ver = sys.version_info
    if ver.major != 3 or ver.minor < 10:
        detected = f"{ver.major}.{ver.minor}.{ver.micro}"
        print(f"\033[91m[ERROR]\033[0m Python 3.10+ is required. You have: \033[93m{detected}\033[0m")
        sys.exit(1)

    if os.environ.get("CONDA_PREFIX") is None:
        print(
            "\033[93m[WARNING]\033[0m CONDA_PREFIX is not set. "
            "When manually installed, Cosmos should run under the cosmos-transfer1 conda environment (see INSTALL.md). "
            "This warning can be ignored when running in the container."
        )

    print("Attempting to import critical packages...")
    core = ["torch", "torchvision", "transformers", "megatron.core", "transformer_engine", "vllm", "pandas"]
    training_only = ["apex.multi_tensor_apply"]

    ok = check_packages(core)
    # Training-only packages are checked only when requested.
    if opts.training and not check_packages(training_only):
        ok = False

    if ok:
        print("-----------------------------------------------------------")
        print("\033[92m[SUCCESS]\033[0m Cosmos environment setup is successful!")
    return ok
if __name__ == "__main__":
    # Print the result and also propagate it as the process exit status so
    # shell scripts and CI can detect a failed check (consistent with main()
    # already exiting with status 1 on an unsupported Python version).
    env_ok = main()
    print(f"Environment check success? {env_ok}")
    sys.exit(0 if env_ok else 1)
|