import argparse
import os
import time
from moge.model.v1 import MoGeModel
import torch
import numpy as np
from cosmos_predict1.diffusion.inference.gen3c_pipeline import Gen3cPipeline
from cosmos_predict1.diffusion.inference.gen3c_single_image import (
create_parser as create_parser_base,
validate_args as validate_args_base,
_predict_moge_depth,
_predict_moge_depth_from_tensor
)
from cosmos_predict1.utils import log, misc
from cosmos_predict1.utils.distributed import device_with_rank, is_rank0, get_rank
from cosmos_predict1.utils.io import save_video
from cosmos_predict1.diffusion.inference.cache_3d import Cache3D_Buffer, Cache4D
import torch.nn.functional as F
def create_parser():
return create_parser_base()
def validate_args(args: argparse.Namespace):
validate_args_base(args)
assert args.batch_input_path is None, "Unsupported in persistent mode"
assert args.prompt is not None, "Prompt is required in persistent mode (but it can be the empty string)"
assert args.input_image_path is None, "Image should be provided directly by value in persistent mode"
assert args.trajectory in (None, 'none'), "Trajectory should be provided directly by value in persistent mode, set --trajectory=none"
assert not args.video_save_name, f"Video saving name will be set automatically for each inference request. Found string: \"{args.video_save_name}\""
def resize_intrinsics(intrinsics: np.ndarray | torch.Tensor,
old_size: tuple[int, int], new_size: tuple[int, int],
crop_size: tuple[int, int] | None = None) -> np.ndarray | torch.Tensor:
# intrinsics: (N, 3, 3) pinhole matrices
# old_size: (h1, w1)
# new_size: (h2, w2)
# crop_size: optional (h, w) center crop applied after resizing (in pixels)
if isinstance(intrinsics, np.ndarray):
intrinsics_copy = np.copy(intrinsics)
elif isinstance(intrinsics, torch.Tensor):
intrinsics_copy = intrinsics.clone()
else:
raise ValueError(f"Invalid intrinsics type: {type(intrinsics)}")
intrinsics_copy[:, 0, :] *= new_size[1] / old_size[1]
intrinsics_copy[:, 1, :] *= new_size[0] / old_size[0]
if crop_size is not None:
intrinsics_copy[:, 0, -1] = intrinsics_copy[:, 0, -1] - (new_size[1] - crop_size[1]) / 2
intrinsics_copy[:, 1, -1] = intrinsics_copy[:, 1, -1] - (new_size[0] - crop_size[0]) / 2
return intrinsics_copy
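# Worked example (illustrative): for pinhole intrinsics
# K = [[fx, 0, cx], [0, fy, cy], [0, 0, 1]] in pixels, resizing from
# (h1, w1) = (480, 640) to (h2, w2) = (704, 1280) multiplies the first row
# (fx, cx) by 1280 / 640 = 2.0 and the second row (fy, cy) by 704 / 480 ≈ 1.47.
# A subsequent center crop only shifts the principal point (cx, cy) by half of
# the cropped-away width / height; the focal lengths are unaffected.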
class Gen3cPersistentModel():
"""Helper class to run Gen3C image-to-video or video-to-video inference.
This class loads the models only once and can be reused for multiple inputs.
This class handles the main video-to-world generation pipeline, including:
- Setting up the random seed for reproducibility
- Initializing the generation pipeline with the provided configuration
- Processing single or multiple prompts/images/videos from input
- Generating videos from prompts and images/videos
- Saving the generated videos and corresponding prompts to disk
Args:
cfg (argparse.Namespace): Configuration namespace containing:
- Model configuration (checkpoint paths, model settings)
- Generation parameters (guidance, steps, dimensions)
- Input/output settings (prompts/images/videos, save paths)
- Performance options (model offloading settings)
Each inference call will save:
- A generated MP4 video file
- Optionally, the rendered 3D-cache warp buffers stacked alongside the video
"""
@torch.no_grad()
def __init__(self, args: argparse.Namespace):
misc.set_random_seed(args.seed)
validate_args(args)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if args.num_gpus > 1:
from megatron.core import parallel_state
from cosmos_predict1.utils import distributed
distributed.init()
parallel_state.initialize_model_parallel(context_parallel_size=args.num_gpus)
process_group = parallel_state.get_context_parallel_group()
self.frames_per_batch = 121
self.inference_overlap_frames = 1
# Initialize video2world generation model pipeline
pipeline = Gen3cPipeline(
inference_type="video2world",
checkpoint_dir=args.checkpoint_dir,
checkpoint_name="Gen3C-Cosmos-7B",
prompt_upsampler_dir=args.prompt_upsampler_dir,
enable_prompt_upsampler=not args.disable_prompt_upsampler,
offload_network=args.offload_diffusion_transformer,
offload_tokenizer=args.offload_tokenizer,
offload_text_encoder_model=args.offload_text_encoder_model,
offload_prompt_upsampler=args.offload_prompt_upsampler,
offload_guardrail_models=args.offload_guardrail_models,
disable_guardrail=args.disable_guardrail,
guidance=args.guidance,
num_steps=args.num_steps,
height=args.height,
width=args.width,
fps=args.fps,
num_video_frames=self.frames_per_batch,
seed=args.seed,
)
if args.num_gpus > 1:
pipeline.model.net.enable_context_parallel(process_group)
self.args = args
self.frame_buffer_max = pipeline.model.frame_buffer_max
self.generator = torch.Generator(device=device).manual_seed(args.seed)
self.sample_n_frames = pipeline.model.chunk_size
self.moge_model = MoGeModel.from_pretrained("Ruicheng/moge-vitl").to(device)
self.pipeline = pipeline
self.device = device
self.device_with_rank = device_with_rank(self.device)
self.cache: Cache3D_Buffer | Cache4D | None = None
self.model_was_seeded = False
# User-provided seeding image, after pre-processing.
# Shape [B, C, T, H, W], type float, range [-1, 1].
self.seeding_image: torch.Tensor | None = None
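# Typical lifecycle: construct once, then for each request call
# seed_model_from_values() (single image, or multi-frame with depths + masks),
# run one or more inference_on_cameras() calls, and call clear_cache() before
# seeding the next scene. cleanup() tears down the optional model-parallel
# process group at shutdown.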
@torch.no_grad()
def seed_model_from_values(self,
images_np: np.ndarray,
depths_np: np.ndarray | None,
world_to_cameras_np: np.ndarray,
focal_lengths_np: np.ndarray,
principal_point_rel_np: np.ndarray,
resolutions: np.ndarray,
masks_np: np.ndarray | None = None):
import torchvision.transforms.functional as transforms_F
# Check inputs
n = images_np.shape[0]
assert images_np.shape[-1] == 3
assert world_to_cameras_np.shape == (n, 4, 4)
assert focal_lengths_np.shape == (n, 2)
assert principal_point_rel_np.shape == (n, 2)
assert resolutions.shape == (n, 2)
assert (depths_np is None) or (depths_np.shape == images_np.shape[:-1])
assert (masks_np is None) or (masks_np.shape == images_np.shape[:-1])
if n == 1:
# TODO: allow user to provide depths, extrinsics and intrinsics
assert depths_np is None, "Not supported yet: directly providing pre-estimated depth values along with a single image."
# Note: the image is received as a 0..1 float array, but MoGe expects the 0..255 uint8 range.
input_image_np = images_np[0, ...] * 255.0
del images_np
# Predict depth and initialize 3D cache.
# Note: even though MoGe may internally use a different resolution, all of the
# outputs are properly resized and adapted to our target (self.args.height,
# self.args.width) resolution, including the intrinsics.
(
moge_image_b1chw_float,
moge_depth_b11hw,
moge_mask_b11hw,
moge_initial_w2c_b144,
moge_intrinsics_b133,
) = _predict_moge_depth(
input_image_np, self.args.height, self.args.width, self.device_with_rank, self.moge_model
)
# TODO: MoGe provides camera params; is it okay to just ignore the user-provided ones?
input_image = moge_image_b1chw_float[:, 0].clone()
self.cache = Cache3D_Buffer(
frame_buffer_max=self.frame_buffer_max,
generator=self.generator,
noise_aug_strength=self.args.noise_aug_strength,
input_image=input_image, # [B, C, H, W]
input_depth=moge_depth_b11hw[:, 0], # [B, 1, H, W]
# input_mask=moge_mask_b11hw[:, 0], # [B, 1, H, W]
input_w2c=moge_initial_w2c_b144[:, 0], # [B, 4, 4]
input_intrinsics=moge_intrinsics_b133[:, 0], # [B, 3, 3]
filter_points_threshold=self.args.filter_points_threshold,
foreground_masking=self.args.foreground_masking,
)
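# The 3D cache now holds the seed image unprojected to a point cloud using the
# estimated depth; render_cache() later re-renders ("warps") that geometry from
# the requested camera poses to produce the guidance images and masks that are
# fed to the diffusion model.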
seeding_image = input_image_np.transpose(2, 0, 1)[None, ...] / 128.0 - 1.0
seeding_image = torch.from_numpy(seeding_image).to(self.device_with_rank)
# Return the estimated extrinsics and intrinsics in the same format as the input
estimated_w2c_b44_np = moge_initial_w2c_b144.cpu().numpy()[:, 0, ...]
moge_intrinsics_b133_np = moge_intrinsics_b133.cpu().numpy()
estimated_focal_lengths_b2_np = np.stack([moge_intrinsics_b133_np[:, 0, 0, 0],
moge_intrinsics_b133_np[:, 0, 1, 1]], axis=1)
estimated_principal_point_rel_b2_np = moge_intrinsics_b133_np[:, 0, :2, 2]
else:
if depths_np is None:
raise NotImplementedError("Seeding from multiple frames requires providing depth values.")
if masks_np is None:
raise NotImplementedError("Seeding from multiple frames requires providing mask values.")
# RGB: [B, H, W, C] to [B, C, H, W]
image_bchw_float = torch.from_numpy(images_np.transpose(0, 3, 1, 2).astype(np.float32)).to(self.device_with_rank)
# Images are received as 0..1 float32, we convert to -1..1 range.
image_bchw_float = (image_bchw_float * 2.0) - 1.0
del images_np
# Depth: [B, H, W] to [B, 1, H, W]
depth_b1hw = torch.from_numpy(depths_np[:, None, ...].astype(np.float32)).to(self.device_with_rank)
# Mask: [B, H, W] to [B, 1, H, W]
mask_b1hw = torch.from_numpy(masks_np[:, None, ...].astype(np.float32)).to(self.device_with_rank)
# World-to-camera: [B, 4, 4]
initial_w2c_b44 = torch.from_numpy(world_to_cameras_np).to(self.device_with_rank)
# Intrinsics: [B, 3, 3]
intrinsics_b33_np = np.zeros((n, 3, 3), dtype=np.float32)
intrinsics_b33_np[:, 0, 0] = focal_lengths_np[:, 0]
intrinsics_b33_np[:, 1, 1] = focal_lengths_np[:, 1]
intrinsics_b33_np[:, 0, 2] = principal_point_rel_np[:, 0] * self.args.width
intrinsics_b33_np[:, 1, 2] = principal_point_rel_np[:, 1] * self.args.height
intrinsics_b33_np[:, 2, 2] = 1.0
intrinsics_b33 = torch.from_numpy(intrinsics_b33_np).to(self.device_with_rank)
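# Note: focal lengths are taken as given (pixels), while the principal point is
# provided relative to the image size; e.g. with width=1280, height=704 and
# principal_point_rel=(0.5, 0.5), the matrix gets cx = 640 and cy = 352.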
self.cache = Cache4D(
input_image=image_bchw_float.clone(), # [B, C, H, W]
input_depth=depth_b1hw, # [B, 1, H, W]
input_mask=mask_b1hw, # [B, 1, H, W]
input_w2c=initial_w2c_b44, # [B, 4, 4]
input_intrinsics=intrinsics_b33, # [B, 3, 3]
filter_points_threshold=self.args.filter_points_threshold,
foreground_masking=self.args.foreground_masking,
input_format=["F", "C", "H", "W"],
)
# Return the given extrinsics and intrinsics in the same format as the input
seeding_image = image_bchw_float
estimated_w2c_b44_np = world_to_cameras_np
estimated_focal_lengths_b2_np = focal_lengths_np
estimated_principal_point_rel_b2_np = principal_point_rel_np
# Resize seeding image to match the desired resolution.
if (seeding_image.shape[2] != self.H) or (seeding_image.shape[3] != self.W):
# TODO: would it be better to crop if aspect ratio is off?
seeding_image = transforms_F.resize(
seeding_image,
size=(self.H, self.W), # type: ignore
interpolation=transforms_F.InterpolationMode.BICUBIC,
antialias=True,
)
# Switch from [B, C, H, W] to [B, C, T, H, W].
self.seeding_image = seeding_image[:, :, None, ...]
working_resolutions_b2_np = np.tile([[self.args.width, self.args.height]], (n, 1))
return (
estimated_w2c_b44_np,
estimated_focal_lengths_b2_np,
estimated_principal_point_rel_b2_np,
working_resolutions_b2_np
)
@torch.no_grad()
def inference_on_cameras(self, view_cameras_w2cs: np.ndarray, view_camera_intrinsics: np.ndarray,
fps: int | float,
overlap_frames:int = 1,
return_estimated_depths: bool = False,
video_save_quality: int = 5,
save_buffer: bool | None = None) -> dict | None:
# TODO: this is not safe if multiple inference requests are served in parallel.
# TODO: also, it's not 100% clear whether it is correct to override the fps
# after the pipeline has been initialized.
self.pipeline.fps = int(fps)
del fps
save_buffer = save_buffer if (save_buffer is not None) else self.args.save_buffer
video_save_name = self.args.video_save_name
if not video_save_name:
video_save_name = f"video_{time.strftime('%Y-%m-%d_%H-%M-%S')}"
video_save_path = os.path.join(self.args.video_save_folder, f"{video_save_name}.mp4")
os.makedirs(self.args.video_save_folder, exist_ok=True)
cache_is_multiframe = isinstance(self.cache, Cache4D)
# Note: the inference server already adjusted intrinsics to match our
# inference resolution (self.W, self.H), so this call is just to make sure
# that all tensors have the right shape, etc.
view_cameras_w2cs, view_camera_intrinsics = self.prepare_camera_for_inference(
view_cameras_w2cs, view_camera_intrinsics,
old_size=(self.H, self.W), new_size=(self.H, self.W)
)
n_frames_total = view_cameras_w2cs.shape[1]
num_ar_iterations = (n_frames_total - overlap_frames) // (self.sample_n_frames - overlap_frames)
log.info(f"Generating {n_frames_total} frames will take {num_ar_iterations} auto-regressive iterations")
# Note: camera trajectory is given by the user, no need to generate it.
log.info(f"Generating frames 0 - {self.sample_n_frames} (out of {n_frames_total} total)...")
rendered_warp_images, rendered_warp_masks = self.cache.render_cache(
view_cameras_w2cs[:, 0:self.sample_n_frames],
view_camera_intrinsics[:, 0:self.sample_n_frames],
start_frame_idx=0,
)
all_rendered_warps = []
all_predicted_depth = []
if save_buffer:
all_rendered_warps.append(rendered_warp_images.clone().cpu())
current_prompt = self.args.prompt
if current_prompt is None and self.args.disable_prompt_upsampler:
log.critical("Prompt is missing, skipping world generation.")
return
# Generate video
starting_frame = self.seeding_image
if cache_is_multiframe:
starting_frame = starting_frame[0].unsqueeze(0)
generated_output = self.pipeline.generate(
prompt=current_prompt,
image_path=starting_frame,
negative_prompt=self.args.negative_prompt,
rendered_warp_images=rendered_warp_images,
rendered_warp_masks=rendered_warp_masks,
)
if generated_output is None:
log.critical("Guardrail blocked video2world generation.")
return
video, _ = generated_output
def depth_for_frame(frame: np.ndarray | torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
last_frame_hwc_0_255 = torch.tensor(frame, device=self.device_with_rank)
pred_image_for_depth_chw_0_1 = last_frame_hwc_0_255.permute(2, 0, 1) / 255.0 # (C,H,W), range [0,1]
pred_depth, pred_mask = _predict_moge_depth_from_tensor(
pred_image_for_depth_chw_0_1, self.moge_model
)
return pred_depth, pred_mask, pred_image_for_depth_chw_0_1
# We predict depth either if we need it (multi-round generation without depth in the cache),
# or if the user requested it explicitly.
need_depth_of_latest_frame = return_estimated_depths or (num_ar_iterations > 1 and not cache_is_multiframe)
if need_depth_of_latest_frame:
pred_depth, _, pred_image_for_depth_chw_0_1 = depth_for_frame(video[-1])
if return_estimated_depths:
# For easier indexing, we include entries even for the frames for which we don't predict
# depth. Since the results will be transmitted in compressed format, this hopefully
# shouldn't take up any additional bandwidth.
depths_batch_0 = np.full((video.shape[0], 1, self.H, self.W), fill_value=np.nan,
dtype=np.float32)
depths_batch_0[-1, ...] = pred_depth.cpu().numpy()
all_predicted_depth.append(depths_batch_0)
del depths_batch_0
# Autoregressive generation (if needed)
for num_iter in range(1, num_ar_iterations):
# Overlap by `overlap_frames` frames
start_frame_idx = num_iter * (self.sample_n_frames - overlap_frames)
end_frame_idx = start_frame_idx + self.sample_n_frames
log.info(f"Generating frames {start_frame_idx} - {end_frame_idx} (out of {n_frames_total} total)...")
if cache_is_multiframe:
# Nothing much to do, we assume that depth is already provided and
# all frames of the seeding video are already in the cache.
pred_image_for_depth_chw_0_1 = torch.tensor(
video[-1], device=self.device_with_rank
).permute(2, 0, 1) / 255.0 # (C,H,W), range [0,1]
else:
self.cache.update_cache(
new_image=pred_image_for_depth_chw_0_1.unsqueeze(0) * 2 - 1, # (B,C,H,W) range [-1,1]
new_depth=pred_depth, # (1,1,H,W)
# new_mask=pred_mask, # (1,1,H,W)
new_w2c=view_cameras_w2cs[:, start_frame_idx],
new_intrinsics=view_camera_intrinsics[:, start_frame_idx],
)
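# For single-image seeding, the 3D cache is extended with the newly generated
# frame, its MoGe-estimated depth and its camera pose, so that the next chunk's
# warp renders stay consistent with everything generated so far.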
current_segment_w2cs = view_cameras_w2cs[:, start_frame_idx:end_frame_idx]
current_segment_intrinsics = view_camera_intrinsics[:, start_frame_idx:end_frame_idx]
cache_start_frame_idx = 0
if cache_is_multiframe:
# If requesting more frames than are available in the cache,
# freeze (hold) on the last batch of frames.
cache_start_frame_idx = min(
start_frame_idx,
self.cache.input_frame_count() - (end_frame_idx - start_frame_idx)
)
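# Example: with a 200-frame seed video, a 121-frame window and this being the
# second chunk (start_frame_idx = 120), the cache is read from frame
# min(120, 200 - 121) = 79, i.e. the last 121 cached frames are held instead of
# reading past the end of the seed video.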
rendered_warp_images, rendered_warp_masks = self.cache.render_cache(
current_segment_w2cs,
current_segment_intrinsics,
start_frame_idx=cache_start_frame_idx,
)
if save_buffer:
all_rendered_warps.append(rendered_warp_images[:, overlap_frames:].clone().cpu())
pred_image_for_depth_bcthw_minus1_1 = pred_image_for_depth_chw_0_1.unsqueeze(0).unsqueeze(2) * 2 - 1 # (B,C,T,H,W), range [-1,1]
generated_output = self.pipeline.generate(
prompt=current_prompt,
image_path=pred_image_for_depth_bcthw_minus1_1,
negative_prompt=self.args.negative_prompt,
rendered_warp_images=rendered_warp_images,
rendered_warp_masks=rendered_warp_masks,
)
if generated_output is None:
    log.critical("Guardrail blocked video2world generation.")
    return
video_new, _ = generated_output
video = np.concatenate([video, video_new[overlap_frames:]], axis=0)
# Prepare depth prediction for the next AR iteration.
need_depth_of_latest_frame = return_estimated_depths or ((num_iter < num_ar_iterations - 1) and not cache_is_multiframe)
if need_depth_of_latest_frame:
# Either we don't have depth (e.g. single-image seeding), or the user requested
# depth to be returned explicitly.
pred_depth, _, pred_image_for_depth_chw_0_1 = depth_for_frame(video_new[-1])
if return_estimated_depths:
depths_batch_i = np.full((video_new.shape[0] - overlap_frames, 1, self.H, self.W),
fill_value=np.nan, dtype=np.float32)
depths_batch_i[-1, ...] = pred_depth.cpu().numpy()
all_predicted_depth.append(depths_batch_i)
del depths_batch_i
if is_rank0():
# Final video processing
final_video_to_save = video
final_width = self.args.width
if save_buffer and all_rendered_warps:
squeezed_warps = [t.squeeze(0) for t in all_rendered_warps] # Each is (T_chunk, n_i, C, H, W)
if squeezed_warps:
n_max = max(t.shape[1] for t in squeezed_warps)
padded_t_list = []
for sq_t in squeezed_warps:
# sq_t shape: (T_chunk, n_i, C, H, W)
current_n_i = sq_t.shape[1]
padding_needed_dim1 = n_max - current_n_i
pad_spec = (0,0, # W
0,0, # H
0,0, # C
0,padding_needed_dim1, # n_i
0,0) # T_chunk
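# F.pad takes (left, right) padding pairs starting from the *last* dimension, so
# the pairs above address W, H, C, n_i, T in that order; only the buffer axis
# n_i is padded, with -1.0 (black once mapped from [-1, 1] to [0, 255]).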
padded_t = F.pad(sq_t, pad_spec, mode='constant', value=-1.0)
padded_t_list.append(padded_t)
full_rendered_warp_tensor = torch.cat(padded_t_list, dim=0)
T_total, _, C_dim, H_dim, W_dim = full_rendered_warp_tensor.shape
buffer_video_TCHnW = full_rendered_warp_tensor.permute(0, 2, 3, 1, 4)
buffer_video_TCHWstacked = buffer_video_TCHnW.contiguous().view(T_total, C_dim, H_dim, n_max * W_dim)
buffer_video_TCHWstacked = (buffer_video_TCHWstacked * 0.5 + 0.5) * 255.0
buffer_numpy_TCHWstacked = buffer_video_TCHWstacked.cpu().numpy().astype(np.uint8)
buffer_numpy_THWC = np.transpose(buffer_numpy_TCHWstacked, (0, 2, 3, 1))
final_video_to_save = np.concatenate([buffer_numpy_THWC, final_video_to_save], axis=2)
final_width = self.args.width * (1 + n_max)
log.info(f"Concatenating video with {n_max} warp buffers. Final video width will be {final_width}")
else:
log.info("No warp buffers to save.")
# Save video
save_video(
video=final_video_to_save,
fps=self.pipeline.fps,
H=self.args.height,
W=final_width,
video_save_quality=video_save_quality,
video_save_path=video_save_path,
)
log.info(f"Saved video to {video_save_path}")
if return_estimated_depths:
predicted_depth = np.concatenate(all_predicted_depth, axis=0)
else:
predicted_depth = None
# Currently `video` is [n_frames, height, width, channels].
# Return as [1, n_frames, channels, height, width] for consistency with other codebases.
video = video.transpose(0, 3, 1, 2)[None, ...]
# Depth is returned as [n_frames, channels, height, width].
# TODO: handle overlap
rendered_warp_images_no_overlap = rendered_warp_images
video_no_overlap = video
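# The dict below returns the full generated clip as "video" ([1, T, C, H, W],
# 0..255), the last chunk's cache renders as "rendered_warp_images", and
# "*_no_overlap" entries that currently alias the full arrays (see the TODO above).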
return {
"rendered_warp_images": rendered_warp_images,
"video": video,
"rendered_warp_images_no_overlap": rendered_warp_images_no_overlap,
"video_no_overlap": video_no_overlap,
"predicted_depth": predicted_depth,
"video_save_path": video_save_path,
}
# --------------------
def prepare_camera_for_inference(self, view_cameras: np.ndarray, view_camera_intrinsics: np.ndarray,
old_size: tuple[int, int], new_size: tuple[int, int]):
"""Old and new sizes should be given as (height, width)."""
if isinstance(view_cameras, np.ndarray):
view_cameras = torch.from_numpy(view_cameras).float().contiguous()
if view_cameras.ndim == 3:
view_cameras = view_cameras.unsqueeze(dim=0)
if isinstance(view_camera_intrinsics, np.ndarray):
view_camera_intrinsics = torch.from_numpy(view_camera_intrinsics).float().contiguous()
view_camera_intrinsics = resize_intrinsics(view_camera_intrinsics, old_size, new_size)
view_camera_intrinsics = view_camera_intrinsics.unsqueeze(dim=0)
assert view_camera_intrinsics.ndim == 4
return view_cameras.to(self.device_with_rank), \
    view_camera_intrinsics.to(self.device_with_rank)
def get_cache_input_depths(self) -> torch.Tensor | None:
if self.cache is None:
return None
return self.cache.input_depth
@property
def W(self) -> int:
return self.args.width
@property
def H(self) -> int:
return self.args.height
def clear_cache(self) -> None:
self.cache = None
self.model_was_seeded = False
def cleanup(self) -> None:
if self.args.num_gpus > 1:
rank = get_rank()
log.info(f"Model cleanup: destroying model parallel group on rank={rank}.",
rank0_only=False)
from megatron.core import parallel_state
parallel_state.destroy_model_parallel()
import torch.distributed as dist
dist.destroy_process_group()
log.info(f"Destroyed model parallel group on rank={rank}.", rank0_only=False)
else:
log.info("Model cleanup: nothing to do (no parallelism).", rank0_only=False)