# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from os.path import join, realpath
import sys
try:
    from typing import override
except ImportError:
    # Python < 3.12: `typing.override` does not exist (PEP 698). Fall back to
    # a no-op decorator so the `@override` markers below remain valid.
    def override(f):
        return f
from loguru import logger as log
import numpy as np
from multi_gpu import MultiGPUInferenceAR
from server_base import ROOT_DIR
from server_cosmos_base import CosmosBaseModel
# Root of the Cosmos-Predict1 checkout; reuses the server's ROOT_DIR.
COSMOS_PREDICT1_ROOT = ROOT_DIR
# Defaults for torchrun-style distributed initialization.
# NOTE(review): not referenced in this file — presumably consumed by the
# multi-GPU launch path; confirm before removing.
TORCHRUN_DEFAULT_MASTER_ADDR = 'localhost'
TORCHRUN_DEFAULT_MASTER_PORT = 12355
def add_cosmos_venv_to_path():
    """Make the Cosmos-Predict1 package importable by extending ``sys.path``.

    Appends the repository root and its ``cosmos_predict1`` subdirectory to
    ``sys.path`` (skipping entries that are already present) so that
    ``import cosmos_predict1...`` works without installing the package.

    Idempotent: calling it multiple times adds no duplicate entries.
    """
    # Fix: the original computed an unused `version_string` local
    # (f"python{major}.{minor}"); it was dead code and has been removed.
    extras = [
        COSMOS_PREDICT1_ROOT,
        join(COSMOS_PREDICT1_ROOT, "cosmos_predict1"),
    ]
    for entry in extras:
        if entry not in sys.path:
            sys.path.append(entry)
class CosmosModel(CosmosBaseModel):
    """
    Serves frames generated on-the-fly by the Cosmos generative model.
    Intended for use with the Cosmos-Predict-1 based Gen3C model.
    """
    def __init__(self, gpu_count: int = 0, **kwargs):
        """Load the Gen3C (Cosmos-Predict1) inference model.

        Args:
            gpu_count: Number of GPUs to run inference on. 0 (the default)
                means "use every CUDA device visible on this machine".
            **kwargs: Forwarded unchanged to ``CosmosBaseModel.__init__``.
        """
        # Must run before the `cosmos_predict1` import below can succeed.
        add_cosmos_venv_to_path()
        # Default the HuggingFace cache to a repo-local directory unless the
        # environment already configures one.
        if not os.environ.get("HF_HOME"):
            os.environ["HF_HOME"] = join(COSMOS_PREDICT1_ROOT, "huggingface_home")
        super().__init__(**kwargs)
        assert os.path.isdir(join(COSMOS_PREDICT1_ROOT, "cosmos_predict1")), \
            f"Could not find Cosmos (cosmos_predict1) directory at: {COSMOS_PREDICT1_ROOT}"
        # Imported lazily: these are only resolvable after
        # add_cosmos_venv_to_path() has extended sys.path.
        from cosmos_predict1.diffusion.inference.gen3c_persistent import Gen3cPersistentModel, create_parser
        import torch
        if gpu_count == 0:
            # Use as many GPUs for inference as are available on this machine.
            gpu_count = torch.cuda.device_count()
        # Note: we use the argparse-based interface so that all defaults are preserved.
        parser = create_parser()
        common_args = [
            "--checkpoint_dir", self.checkpoint_path or join(COSMOS_PREDICT1_ROOT, "checkpoints"),
            "--video_save_name=", # Empty string
            "--video_save_folder", join(COSMOS_PREDICT1_ROOT, "outputs"),
            "--trajectory", "none",
            "--prompt=", # Empty string
            "--negative_prompt=", # Empty string
            "--offload_prompt_upsampler",
            "--disable_prompt_upsampler",
            "--disable_guardrail",
            "--num_gpus", str(gpu_count),
            "--guidance", "1.0",
            "--num_video_frames", "121",
            "--foreground_masking",
        ]
        args = parser.parse_args(common_args)
        # Single-GPU uses the in-process model; multi-GPU spawns worker
        # processes via the MultiGPUInferenceAR wrapper.
        if gpu_count == 1:
            self.model = Gen3cPersistentModel(args)
        else:
            log.info(f"Loading Cosmos-Predict1 inference model on {gpu_count} GPUs.")
            self.model = MultiGPUInferenceAR(gpu_count, cosmos_variant="predict1", args=args)
        # Since the model may require overlap of inference batches,
        # we save previous inference poses so that we can provide any number of
        # previous camera poses when starting the next inference batch.
        # TODO: ensure some kind of ordering?
        self.pose_history_w2c: list[np.ndarray] = []
        self.intrinsics_history: list[np.ndarray] = []
        # Fallback camera intrinsics used when a request provides none.
        # NOTE(review): focal length appears to be in pixels while the
        # principal point looks normalized to [0, 1] — confirm against callers.
        self.default_focal_length = (338.29, 338.29)
        self.default_principal_point = (0.5, 0.5)
        # Axis-aligned scene bounds advertised to clients.
        self.aabb_min = np.array([-16, -16, -16])
        self.aabb_max = np.array([16, 16, 16])
    def inference_resolution(self) -> list[tuple[int, int]] | None:
        """The supported inference resolutions, or None if any resolution is supported."""
        # NOTE(review): unlike the methods below, this one is not marked
        # @override — presumably it also overrides the base class; confirm.
        return [(1280, 704),]
    @override
    def max_frames_per_request(self) -> int:
        """Upper bound on the number of frames a single request may ask for."""
        # Not actually tested, but anyway we can expect autoregressive
        # generation to go wrong earlier than this.
        return self.model.frames_per_batch * 100
    @override
    def default_framerate(self) -> float:
        """Framerate (frames per second) used when the client does not specify one."""
        return 24.0
    def cleanup(self) -> None:
        """Release resources; only the multi-GPU wrapper needs explicit cleanup."""
        if isinstance(self.model, MultiGPUInferenceAR):
            self.model.cleanup()
    @override
    def metadata(self) -> dict:
        """Return the base-class metadata extended with this model's name."""
        result = super().metadata()
        result["model_name"] = "CosmosModel"
        return result
if __name__ == "__main__":
    # Smoke test: load the model with defaults (all available GPUs, default
    # checkpoint directory). Requires CUDA devices and downloaded checkpoints.
    model = CosmosModel()