File size: 4,332 Bytes
28451f7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
# 
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# 
# http://www.apache.org/licenses/LICENSE-2.0
# 
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from os.path import join, realpath
import sys
try:
	from typing import override
except ImportError:
	# `typing.override` is only available on newer Python versions (3.12+).
	# On older interpreters, fall back to a decorator that does nothing so that
	# `@override` can still be used as a marker on methods in this file.
	def override(f):
		"""No-op stand-in for `typing.override`: return `f` unchanged."""
		return f

from loguru import logger as log
import numpy as np

from multi_gpu import MultiGPUInferenceAR
from server_base import ROOT_DIR
from server_cosmos_base import CosmosBaseModel

# Directory under which this module looks for the `cosmos_predict1` package,
# the default `checkpoints` folder and the `outputs` folder.
# Currently identical to the server's ROOT_DIR.
COSMOS_PREDICT1_ROOT = ROOT_DIR

# NOTE(review): these two constants are not referenced in the visible part of
# this file. Presumably they are the default rendezvous address / port for
# torchrun-style multi-GPU launches and are imported elsewhere — confirm.
TORCHRUN_DEFAULT_MASTER_ADDR = 'localhost'
TORCHRUN_DEFAULT_MASTER_PORT = 12355


def add_cosmos_venv_to_path():
	"""
	Make the Cosmos-Predict1 sources importable from this process.

	Appends `COSMOS_PREDICT1_ROOT` and its `cosmos_predict1` subdirectory to
	`sys.path`. Entries that are already present are skipped, so calling this
	function more than once does not add duplicates.

	Note: despite the name, no virtual-environment `site-packages` directory
	is added here; only the two source directories above.
	"""
	extras = [
		COSMOS_PREDICT1_ROOT,
		join(COSMOS_PREDICT1_ROOT, "cosmos_predict1"),
	]
	for e in extras:
		# Append (rather than prepend) so existing entries keep priority.
		if e not in sys.path:
			sys.path.append(e)


class CosmosModel(CosmosBaseModel):
	"""
	Serves frames generated on-the-fly by the Cosmos generative model.
	Intended for use with the Cosmos-Predict-1 based Gen3C model.
	"""

	def __init__(self, gpu_count: int = 0, **kwargs):
		"""
		Load the Gen3C persistent inference model on one or several GPUs.

		Args:
			gpu_count: Number of GPUs to use for inference. `0` (the default)
				means "use as many CUDA devices as torch reports".
			**kwargs: Forwarded unchanged to `CosmosBaseModel.__init__`.

		Raises:
			AssertionError: If no `cosmos_predict1` directory exists under
				`COSMOS_PREDICT1_ROOT`.
		"""
		# Extend `sys.path` first: the deferred `cosmos_predict1` import below
		# presumably relies on these entries.
		add_cosmos_venv_to_path()
		# Only provide a default HuggingFace cache location when the user has
		# not set (or has emptied) HF_HOME themselves.
		if not os.environ.get("HF_HOME"):
			os.environ["HF_HOME"] = join(COSMOS_PREDICT1_ROOT, "huggingface_home")

		super().__init__(**kwargs)

		# Explicit check instead of an `assert` statement, so that the validation
		# is not stripped when Python runs with `-O`. The exception type and
		# message are kept identical to the previous `assert`.
		if not os.path.isdir(join(COSMOS_PREDICT1_ROOT, "cosmos_predict1")):
			raise AssertionError(
				f"Could not find Cosmos (cosmos_predict1) directory at: {COSMOS_PREDICT1_ROOT}"
			)

		# Deferred imports: these are only resolved once the path setup and the
		# directory check above have run.
		from cosmos_predict1.diffusion.inference.gen3c_persistent import Gen3cPersistentModel, create_parser
		import torch

		if gpu_count == 0:
			# Use as many GPUs for inference as are available on this machine.
			# NOTE(review): if no CUDA device is visible this stays 0 and the
			# multi-GPU branch below is taken with a count of 0 — confirm that
			# this is handled (or rejected) downstream.
			gpu_count = torch.cuda.device_count()

		# Note: we use the argparse-based interface so that all defaults are preserved.
		parser = create_parser()
		common_args = [
			# `self.checkpoint_path` is not set in this class; presumably it is
			# provided by the base class constructor.
			"--checkpoint_dir", self.checkpoint_path or join(COSMOS_PREDICT1_ROOT, "checkpoints"),
			"--video_save_name=", # Empty string
			"--video_save_folder", join(COSMOS_PREDICT1_ROOT, "outputs"),
			"--trajectory", "none",
			"--prompt=", # Empty string
			"--negative_prompt=", # Empty string
			"--offload_prompt_upsampler",
			"--disable_prompt_upsampler",
			"--disable_guardrail",
			"--num_gpus", str(gpu_count),
			"--guidance", "1.0",
			"--num_video_frames", "121",
			"--foreground_masking",
		]
		args = parser.parse_args(common_args)

		if gpu_count == 1:
			# Single GPU: run the model directly in this process.
			self.model = Gen3cPersistentModel(args)
		else:
			log.info(f"Loading Cosmos-Predict1 inference model on {gpu_count} GPUs.")
			self.model = MultiGPUInferenceAR(gpu_count, cosmos_variant="predict1", args=args)

		# Since the model may require overlap of inference batches,
		# we save previous inference poses so that we can provide any number of
		# previous camera poses when starting the next inference batch.
		# TODO: ensure some kind of ordering?
		self.pose_history_w2c: list[np.ndarray] = []
		self.intrinsics_history: list[np.ndarray] = []

		# Defaults that are not read anywhere else in this class; presumably
		# consumed by the base class or by callers.
		self.default_focal_length = (338.29, 338.29)
		self.default_principal_point = (0.5, 0.5)
		self.aabb_min = np.array([-16, -16, -16])
		self.aabb_max = np.array([16, 16, 16])


	def inference_resolution(self) -> list[tuple[int, int]] | None:
		"""The supported inference resolutions, or None if any resolution is supported."""
		return [(1280, 704),]


	@override
	def max_frames_per_request(self) -> int:
		"""Upper bound on frames per request: 100 times the model's `frames_per_batch`."""
		# Not actually tested, but anyway we can expect autoregressive
		# generation to go wrong earlier than this.
		return self.model.frames_per_batch * 100

	@override
	def default_framerate(self) -> float:
		"""Default framerate reported for this model (a fixed 24.0)."""
		return 24.0

	def cleanup(self):
		"""Call the model's own `cleanup()` when it is the multi-GPU wrapper."""
		# The single-GPU model is left untouched here; only the multi-GPU
		# wrapper's `cleanup()` is invoked.
		if isinstance(self.model, MultiGPUInferenceAR):
			self.model.cleanup()

	@override
	def metadata(self) -> dict:
		"""Return the base class metadata with `model_name` set to "CosmosModel"."""
		result = super().metadata()
		result["model_name"] = "CosmosModel"
		return result


if __name__ == "__main__":
	# Running this file directly only constructs the model with its default
	# arguments (gpu_count=0, i.e. every CUDA device torch reports); nothing
	# else is done with the instance here.
	model = CosmosModel()