# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from os.path import join, realpath
import sys

try:
    # typing.override is only available on Python 3.12+.
    from typing import override
except ImportError:
    # Fall back to a no-op decorator on older interpreters.
    def override(f):
        return f

from loguru import logger as log
import numpy as np

from multi_gpu import MultiGPUInferenceAR
from server_base import ROOT_DIR
from server_cosmos_base import CosmosBaseModel

COSMOS_PREDICT1_ROOT = ROOT_DIR
TORCHRUN_DEFAULT_MASTER_ADDR = 'localhost'
TORCHRUN_DEFAULT_MASTER_PORT = 12355
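# The torchrun-style rendezvous defaults above are not read anywhere in this
# file; they appear to be intended for the multi-GPU inference path
# (MultiGPUInferenceAR).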
def add_cosmos_venv_to_path():
    version_string = f"python{sys.version_info.major}.{sys.version_info.minor}"
    extras = [
        COSMOS_PREDICT1_ROOT,
        join(COSMOS_PREDICT1_ROOT, "cosmos_predict1"),
    ]
    for e in extras:
        if e not in sys.path:
            sys.path.append(e)
class CosmosModel(CosmosBaseModel):
    """
    Serves frames generated on-the-fly by the Cosmos generative model.
    Intended for use with the Cosmos-Predict-1 based Gen3C model.
    """

    def __init__(self, gpu_count: int = 0, **kwargs):
        add_cosmos_venv_to_path()
        if not os.environ.get("HF_HOME"):
            os.environ["HF_HOME"] = join(COSMOS_PREDICT1_ROOT, "huggingface_home")
        super().__init__(**kwargs)

        assert os.path.isdir(join(COSMOS_PREDICT1_ROOT, "cosmos_predict1")), \
            f"Could not find Cosmos (cosmos_predict1) directory at: {COSMOS_PREDICT1_ROOT}"
        # Imported here, after sys.path and HF_HOME have been set up, rather than
        # at module level.
        from cosmos_predict1.diffusion.inference.gen3c_persistent import Gen3cPersistentModel, create_parser
        import torch

        if gpu_count == 0:
            # Use as many GPUs for inference as are available on this machine.
            gpu_count = torch.cuda.device_count()

        # Note: we use the argparse-based interface so that all defaults are preserved.
        parser = create_parser()
        common_args = [
            "--checkpoint_dir", self.checkpoint_path or join(COSMOS_PREDICT1_ROOT, "checkpoints"),
            "--video_save_name=",  # Empty string
            "--video_save_folder", join(COSMOS_PREDICT1_ROOT, "outputs"),
            "--trajectory", "none",
            "--prompt=",  # Empty string
            "--negative_prompt=",  # Empty string
            "--offload_prompt_upsampler",
            "--disable_prompt_upsampler",
            "--disable_guardrail",
            "--num_gpus", str(gpu_count),
            "--guidance", "1.0",
            "--num_video_frames", "121",
            "--foreground_masking",
        ]
        args = parser.parse_args(common_args)
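        # Because parse_args() is given an explicit argument list, the server's
        # own command-line arguments are not consumed here, and any option not
        # listed above keeps the default defined by create_parser().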
        if gpu_count == 1:
            self.model = Gen3cPersistentModel(args)
        else:
            log.info(f"Loading Cosmos-Predict1 inference model on {gpu_count} GPUs.")
            self.model = MultiGPUInferenceAR(gpu_count, cosmos_variant="predict1", args=args)

        # Since the model may require overlap of inference batches,
        # we save previous inference poses so that we can provide any number of
        # previous camera poses when starting the next inference batch.
        # TODO: ensure some kind of ordering?
        self.pose_history_w2c: list[np.ndarray] = []
        self.intrinsics_history: list[np.ndarray] = []

        # Defaults used when a request does not specify camera intrinsics
        # (principal point in normalized image coordinates).
        self.default_focal_length = (338.29, 338.29)
        self.default_principal_point = (0.5, 0.5)

        # Axis-aligned bounding box of the scene.
        self.aabb_min = np.array([-16, -16, -16])
        self.aabb_max = np.array([16, 16, 16])
    @override
    def inference_resolution(self) -> list[tuple[int, int]] | None:
        """The supported inference resolutions, or None if any resolution is supported."""
        return [(1280, 704),]

    @override
    def max_frames_per_request(self) -> int:
        # Not actually tested, but in any case we can expect autoregressive
        # generation to go wrong earlier than this.
        return self.model.frames_per_batch * 100

    @override
    def default_framerate(self) -> float:
        return 24.0

    @override
    def cleanup(self):
        if isinstance(self.model, MultiGPUInferenceAR):
            self.model.cleanup()

    @override
    def metadata(self) -> dict:
        result = super().metadata()
        result["model_name"] = "CosmosModel"
        return result
if __name__ == "__main__": | |
model = CosmosModel() | |
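    # Minimal smoke-test sketch: instantiating CosmosModel above loads the model;
    # the calls below only exercise the lightweight accessors defined in this
    # file. Actual request serving is assumed to be handled by CosmosBaseModel /
    # server_base and is not started here.
    log.info(f"Metadata: {model.metadata()}")
    log.info(f"Supported inference resolutions: {model.inference_resolution()}")
    log.info(f"Default framerate: {model.default_framerate()} fps")
    model.cleanup()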