# gen3c/gui/api/server_cosmos.py
# Commit author: elungky
# Commit message: Initial commit for new Space - pre-built Docker image
# Commit: 28451f7
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from os.path import join, realpath
import sys
try:
    from typing import override  # available in the stdlib from Python 3.12
except ImportError:  # older interpreters: fall back to a pass-through decorator
    def override(f):
        """Return ``f`` unchanged; stand-in for ``typing.override``."""
        return f
from loguru import logger as log
import numpy as np
from multi_gpu import MultiGPUInferenceAR
from server_base import ROOT_DIR
from server_cosmos_base import CosmosBaseModel
# Directory expected to contain the `cosmos_predict1` package (aliases the server's ROOT_DIR).
COSMOS_PREDICT1_ROOT = ROOT_DIR
# Default torchrun master address/port. Not referenced in this file's visible code;
# presumably consumed by the multi-GPU launcher (confirm).
TORCHRUN_DEFAULT_MASTER_ADDR, TORCHRUN_DEFAULT_MASTER_PORT = 'localhost', 12355
def add_cosmos_venv_to_path(root=None):
    """
    Make the Cosmos-Predict1 sources importable by extending ``sys.path``.

    Appends ``root`` and ``root/cosmos_predict1`` (in that order) to the end of
    ``sys.path``, skipping any entry that is already present, so calling this
    more than once does not create duplicates.

    Args:
        root: Directory containing the ``cosmos_predict1`` package.
            Defaults to ``COSMOS_PREDICT1_ROOT`` when omitted or ``None``.
    """
    # Note: despite the function name, no virtual environment is activated;
    # only source directories are added to the import path.
    if root is None:
        root = COSMOS_PREDICT1_ROOT
    for entry in (root, join(root, "cosmos_predict1")):
        if entry not in sys.path:
            sys.path.append(entry)
class CosmosModel(CosmosBaseModel):
    """
    Serves frames generated on-the-fly by the Cosmos generative model.
    Intended for use with the Cosmos-Predict-1 based Gen3C model.
    """

    def __init__(self, gpu_count: int = 0, **kwargs):
        """
        Load the Gen3C persistent inference model.

        Args:
            gpu_count: Number of GPUs to run inference on. ``0`` (the default)
                uses every device reported by ``torch.cuda.device_count()``.
                Exactly one GPU runs ``Gen3cPersistentModel`` in-process;
                any other count uses ``MultiGPUInferenceAR``.
            **kwargs: Forwarded unchanged to ``CosmosBaseModel.__init__``.

        Raises:
            FileNotFoundError: If ``COSMOS_PREDICT1_ROOT`` has no
                ``cosmos_predict1`` directory.
            RuntimeError: If ``gpu_count`` is 0 and no CUDA device is found.
        """
        add_cosmos_venv_to_path()
        # Only fill in HF_HOME when it is unset or empty, so a user-provided value wins.
        if not os.environ.get("HF_HOME"):
            os.environ["HF_HOME"] = join(COSMOS_PREDICT1_ROOT, "huggingface_home")
        super().__init__(**kwargs)

        # Explicit check rather than `assert`, which is stripped under `python -O`.
        if not os.path.isdir(join(COSMOS_PREDICT1_ROOT, "cosmos_predict1")):
            raise FileNotFoundError(
                f"Could not find Cosmos (cosmos_predict1) directory at: {COSMOS_PREDICT1_ROOT}"
            )

        # Imported here, after add_cosmos_venv_to_path() has extended sys.path,
        # rather than at module level.
        from cosmos_predict1.diffusion.inference.gen3c_persistent import Gen3cPersistentModel, create_parser
        import torch

        if gpu_count == 0:
            # Use as many GPUs for inference as are available on this machine.
            gpu_count = torch.cuda.device_count()
            if gpu_count == 0:
                # Otherwise we would fall through to MultiGPUInferenceAR(0, ...)
                # with "--num_gpus 0".
                raise RuntimeError("No CUDA devices available for Cosmos inference.")

        # Note: we use the argparse-based interface so that all defaults are preserved.
        parser = create_parser()
        common_args = [
            "--checkpoint_dir", self.checkpoint_path or join(COSMOS_PREDICT1_ROOT, "checkpoints"),
            "--video_save_name=",  # Empty string
            "--video_save_folder", join(COSMOS_PREDICT1_ROOT, "outputs"),
            "--trajectory", "none",
            "--prompt=",  # Empty string
            "--negative_prompt=",  # Empty string
            "--offload_prompt_upsampler",
            "--disable_prompt_upsampler",
            "--disable_guardrail",
            "--num_gpus", str(gpu_count),
            "--guidance", "1.0",
            "--num_video_frames", "121",
            "--foreground_masking",
        ]
        args = parser.parse_args(common_args)

        if gpu_count == 1:
            self.model = Gen3cPersistentModel(args)
        else:
            log.info(f"Loading Cosmos-Predict1 inference model on {gpu_count} GPUs.")
            self.model = MultiGPUInferenceAR(gpu_count, cosmos_variant="predict1", args=args)

        # Since the model may require overlap of inference batches,
        # we save previous inference poses so that we can provide any number of
        # previous camera poses when starting the next inference batch.
        # TODO: ensure some kind of ordering?
        self.pose_history_w2c: list[np.ndarray] = []
        self.intrinsics_history: list[np.ndarray] = []

        self.default_focal_length = (338.29, 338.29)
        self.default_principal_point = (0.5, 0.5)
        self.aabb_min = np.array([-16, -16, -16])
        self.aabb_max = np.array([16, 16, 16])

    def inference_resolution(self) -> list[tuple[int, int]] | None:
        """The supported inference resolutions, or None if any resolution is supported."""
        return [(1280, 704),]

    @override
    def max_frames_per_request(self) -> int:
        """Upper bound on frames per request: 100 batches of ``self.model.frames_per_batch``."""
        # Not actually tested, but anyway we can expect autoregressive
        # generation to go wrong earlier than this.
        return self.model.frames_per_batch * 100

    @override
    def default_framerate(self) -> float:
        """Default output framerate, in frames per second."""
        return 24.0

    def cleanup(self):
        """Shut down the multi-GPU inference wrapper; no-op for the single-GPU model."""
        if isinstance(self.model, MultiGPUInferenceAR):
            self.model.cleanup()

    @override
    def metadata(self) -> dict:
        """Base-class metadata with ``model_name`` set to ``"CosmosModel"``."""
        result = super().metadata()
        result["model_name"] = "CosmosModel"
        return result
if __name__ == "__main__":
    # Smoke test: load the model, then release any multi-GPU workers it started
    # (cleanup() is a no-op for the single-GPU model).
    model = CosmosModel()
    model.cleanup()