import argparse
import os
import time

import numpy as np
import torch
import torch.nn.functional as F
from moge.model.v1 import MoGeModel

from cosmos_predict1.diffusion.inference.cache_3d import Cache3D_Buffer, Cache4D
from cosmos_predict1.diffusion.inference.gen3c_pipeline import Gen3cPipeline
from cosmos_predict1.diffusion.inference.gen3c_single_image import (
    create_parser as create_parser_base,
    validate_args as validate_args_base,
    _predict_moge_depth,
    _predict_moge_depth_from_tensor,
)
from cosmos_predict1.utils import log, misc
from cosmos_predict1.utils.distributed import device_with_rank, get_rank, is_rank0
from cosmos_predict1.utils.io import save_video


def create_parser():
    return create_parser_base()


def validate_args(args: argparse.Namespace):
    validate_args_base(args)
    assert args.batch_input_path is None, "Unsupported in persistent mode"
    assert args.prompt is not None, "Prompt is required in persistent mode (but it can be the empty string)"
    assert args.input_image_path is None, "Image should be provided directly by value in persistent mode"
    assert args.trajectory in (None, "none"), "Trajectory should be provided directly by value in persistent mode, set --trajectory=none"
    assert not args.video_save_name, f"Video saving name will be set automatically for each inference request. Found string: \"{args.video_save_name}\""


def resize_intrinsics(intrinsics: np.ndarray | torch.Tensor,
                      old_size: tuple[int, int], new_size: tuple[int, int],
                      crop_size: tuple[int, int] | None = None) -> np.ndarray | torch.Tensor:
    # intrinsics: (B, 3, 3), in pixel units
    # old_size:   (h1, w1)
    # new_size:   (h2, w2)
    if isinstance(intrinsics, np.ndarray):
        intrinsics_copy = np.copy(intrinsics)
    elif isinstance(intrinsics, torch.Tensor):
        intrinsics_copy = intrinsics.clone()
    else:
        raise ValueError(f"Invalid intrinsics type: {type(intrinsics)}")
    # Scale focal lengths and principal points by the resolution change.
    intrinsics_copy[:, 0, :] *= new_size[1] / old_size[1]
    intrinsics_copy[:, 1, :] *= new_size[0] / old_size[0]
    if crop_size is not None:
        # A center crop shifts the principal point by half of the cropped-away margin.
        intrinsics_copy[:, 0, -1] = intrinsics_copy[:, 0, -1] - (new_size[1] - crop_size[1]) / 2
        intrinsics_copy[:, 1, -1] = intrinsics_copy[:, 1, -1] - (new_size[0] - crop_size[0]) / 2
    return intrinsics_copy
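
# Illustrative sketch (assumed values, not from a real run): rescaling a batch of
# one intrinsics matrix from 720p to a hypothetical 576x1024 working resolution.
#
#   K = np.array([[[800.0,   0.0, 640.0],
#                  [  0.0, 800.0, 360.0],
#                  [  0.0,   0.0,   1.0]]])
#   K_resized = resize_intrinsics(K, old_size=(720, 1280), new_size=(576, 1024))
#   # The x-row is scaled by 1024/1280 = 0.8 and the y-row by 576/720 = 0.8:
#   # fx 800 -> 640, cx 640 -> 512, fy 800 -> 640, cy 360 -> 288.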


class Gen3cPersistentModel:
    """Helper class to run Gen3C image-to-video or video-to-video inference.

    This class loads the models only once and can be reused for multiple inputs.
    It handles the main video-to-world generation pipeline, including:
    - Setting up the random seed for reproducibility
    - Initializing the generation pipeline with the provided configuration
    - Processing single or multiple prompts/images/videos from input
    - Generating videos from prompts and images/videos
    - Saving the generated videos and corresponding prompts to disk

    Args:
        args (argparse.Namespace): Configuration namespace containing:
            - Model configuration (checkpoint paths, model settings)
            - Generation parameters (guidance, steps, dimensions)
            - Input/output settings (prompts/images/videos, save paths)
            - Performance options (model offloading settings)

    The class saves:
    - Generated MP4 video files
    - Text files containing the processed prompts
    """

    def __init__(self, args: argparse.Namespace):
        misc.set_random_seed(args.seed)
        validate_args(args)
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        if args.num_gpus > 1:
            from megatron.core import parallel_state

            from cosmos_predict1.utils import distributed
            distributed.init()
            parallel_state.initialize_model_parallel(context_parallel_size=args.num_gpus)
            process_group = parallel_state.get_context_parallel_group()

        self.frames_per_batch = 121
        self.inference_overlap_frames = 1

        # Initialize the video2world generation model pipeline.
        pipeline = Gen3cPipeline(
            inference_type="video2world",
            checkpoint_dir=args.checkpoint_dir,
            checkpoint_name="Gen3C-Cosmos-7B",
            prompt_upsampler_dir=args.prompt_upsampler_dir,
            enable_prompt_upsampler=not args.disable_prompt_upsampler,
            offload_network=args.offload_diffusion_transformer,
            offload_tokenizer=args.offload_tokenizer,
            offload_text_encoder_model=args.offload_text_encoder_model,
            offload_prompt_upsampler=args.offload_prompt_upsampler,
            offload_guardrail_models=args.offload_guardrail_models,
            disable_guardrail=args.disable_guardrail,
            guidance=args.guidance,
            num_steps=args.num_steps,
            height=args.height,
            width=args.width,
            fps=args.fps,
            num_video_frames=self.frames_per_batch,
            seed=args.seed,
        )
        if args.num_gpus > 1:
            pipeline.model.net.enable_context_parallel(process_group)

        self.args = args
        self.frame_buffer_max = pipeline.model.frame_buffer_max
        self.generator = torch.Generator(device=device).manual_seed(args.seed)
        self.sample_n_frames = pipeline.model.chunk_size
        self.moge_model = MoGeModel.from_pretrained("Ruicheng/moge-vitl").to(device)
        self.pipeline = pipeline
        self.device = device
        self.device_with_rank = device_with_rank(self.device)
        self.cache: Cache3D_Buffer | Cache4D | None = None
        self.model_was_seeded = False
        # User-provided seeding image, after pre-processing.
        # Shape [B, C, T, H, W], type float, range [-1, 1].
        self.seeding_image: torch.Tensor | None = None
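
    # Illustrative usage sketch (hedged: the exact CLI flags come from
    # `create_parser_base` and may differ; the paths below are placeholders):
    #
    #   parser = create_parser()
    #   args = parser.parse_args(["--checkpoint_dir", "checkpoints",
    #                             "--prompt", "", "--trajectory", "none"])
    #   model = Gen3cPersistentModel(args)
    #   # Seed it with an image (see `seed_model_from_values`), run
    #   # `inference_on_cameras(...)` per request, then call `cleanup()`.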

    def seed_model_from_values(self,
                               images_np: np.ndarray,
                               depths_np: np.ndarray | None,
                               world_to_cameras_np: np.ndarray,
                               focal_lengths_np: np.ndarray,
                               principal_point_rel_np: np.ndarray,
                               resolutions: np.ndarray,
                               masks_np: np.ndarray | None = None):
        import torchvision.transforms.functional as transforms_F

        # Check inputs
        n = images_np.shape[0]
        assert images_np.shape[-1] == 3
        assert world_to_cameras_np.shape == (n, 4, 4)
        assert focal_lengths_np.shape == (n, 2)
        assert principal_point_rel_np.shape == (n, 2)
        assert resolutions.shape == (n, 2)
        assert (depths_np is None) or (depths_np.shape == images_np.shape[:-1])
        assert (masks_np is None) or (masks_np.shape == images_np.shape[:-1])

        if n == 1:
            # TODO: allow user to provide depths, extrinsics and intrinsics
            assert depths_np is None, "Not supported yet: directly providing pre-estimated depth values along with a single image."
            # Note: the image is received as 0..1 float, but MoGe expects 0..255 uint8.
            input_image_np = images_np[0, ...] * 255.0
            del images_np
            # Predict depth and initialize the 3D cache.
            # Note: even though MoGe may internally use a different resolution, all of
            # the outputs are properly resized & adapted to our desired
            # (self.args.height, self.args.width) resolution, including the intrinsics.
            (
                moge_image_b1chw_float,
                moge_depth_b11hw,
                moge_mask_b11hw,
                moge_initial_w2c_b144,
                moge_intrinsics_b133,
            ) = _predict_moge_depth(
                input_image_np, self.args.height, self.args.width, self.device_with_rank, self.moge_model
            )

            # TODO: MoGe provides camera params, is it okay to just ignore the user-provided ones?
            input_image = moge_image_b1chw_float[:, 0].clone()
            self.cache = Cache3D_Buffer(
                frame_buffer_max=self.frame_buffer_max,
                generator=self.generator,
                noise_aug_strength=self.args.noise_aug_strength,
                input_image=input_image,                      # [B, C, H, W]
                input_depth=moge_depth_b11hw[:, 0],           # [B, 1, H, W]
                # input_mask=moge_mask_b11hw[:, 0],           # [B, 1, H, W]
                input_w2c=moge_initial_w2c_b144[:, 0],        # [B, 4, 4]
                input_intrinsics=moge_intrinsics_b133[:, 0],  # [B, 3, 3]
                filter_points_threshold=self.args.filter_points_threshold,
                foreground_masking=self.args.foreground_masking,
            )
            # Convert 0..255 to (approximately) [-1, 1] and add a batch dimension.
            seeding_image = input_image_np.transpose(2, 0, 1)[None, ...] / 128.0 - 1.0
            seeding_image = torch.from_numpy(seeding_image).to(self.device_with_rank)

            # Return the estimated extrinsics and intrinsics in the same format as the input.
            estimated_w2c_b44_np = moge_initial_w2c_b144.cpu().numpy()[:, 0, ...]
            moge_intrinsics_b133_np = moge_intrinsics_b133.cpu().numpy()
            estimated_focal_lengths_b2_np = np.stack([moge_intrinsics_b133_np[:, 0, 0, 0],
                                                      moge_intrinsics_b133_np[:, 0, 1, 1]], axis=1)
            estimated_principal_point_rel_b2_np = moge_intrinsics_b133_np[:, 0, :2, 2]
        else:
            if depths_np is None:
                raise NotImplementedError("Seeding from multiple frames requires providing depth values.")
            if masks_np is None:
                raise NotImplementedError("Seeding from multiple frames requires providing mask values.")

            # RGB: [B, H, W, C] to [B, C, H, W]
            image_bchw_float = torch.from_numpy(images_np.transpose(0, 3, 1, 2).astype(np.float32)).to(self.device_with_rank)
            # Images are received as 0..1 float32, we convert to the [-1, 1] range.
            image_bchw_float = (image_bchw_float * 2.0) - 1.0
            del images_np

            # Depth: [B, H, W] to [B, 1, H, W]
            depth_b1hw = torch.from_numpy(depths_np[:, None, ...].astype(np.float32)).to(self.device_with_rank)
            # Mask: [B, H, W] to [B, 1, H, W]
            mask_b1hw = torch.from_numpy(masks_np[:, None, ...].astype(np.float32)).to(self.device_with_rank)
            # World-to-camera: [B, 4, 4]
            initial_w2c_b44 = torch.from_numpy(world_to_cameras_np).to(self.device_with_rank)

            # Intrinsics: [B, 3, 3], with the relative principal point converted to pixels.
            intrinsics_b33_np = np.zeros((n, 3, 3), dtype=np.float32)
            intrinsics_b33_np[:, 0, 0] = focal_lengths_np[:, 0]
            intrinsics_b33_np[:, 1, 1] = focal_lengths_np[:, 1]
            intrinsics_b33_np[:, 0, 2] = principal_point_rel_np[:, 0] * self.args.width
            intrinsics_b33_np[:, 1, 2] = principal_point_rel_np[:, 1] * self.args.height
            intrinsics_b33_np[:, 2, 2] = 1.0
            intrinsics_b33 = torch.from_numpy(intrinsics_b33_np).to(self.device_with_rank)
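            # The assembled matrix has the standard pinhole layout, with fx, fy in
            # pixels and cx, cy converted from relative to pixel coordinates:
            #   K = [[fx,  0, cx],
            #        [ 0, fy, cy],
            #        [ 0,  0,  1]]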
            self.cache = Cache4D(
                input_image=image_bchw_float.clone(),  # [B, C, H, W]
                input_depth=depth_b1hw,                # [B, 1, H, W]
                input_mask=mask_b1hw,                  # [B, 1, H, W]
                input_w2c=initial_w2c_b44,             # [B, 4, 4]
                input_intrinsics=intrinsics_b33,       # [B, 3, 3]
                filter_points_threshold=self.args.filter_points_threshold,
                foreground_masking=self.args.foreground_masking,
                input_format=["F", "C", "H", "W"],
            )
            # Return the given extrinsics and intrinsics in the same format as the input.
            seeding_image = image_bchw_float
            estimated_w2c_b44_np = world_to_cameras_np
            estimated_focal_lengths_b2_np = focal_lengths_np
            estimated_principal_point_rel_b2_np = principal_point_rel_np

        # Resize the seeding image to match the desired resolution.
        if (seeding_image.shape[2] != self.H) or (seeding_image.shape[3] != self.W):
            # TODO: would it be better to crop if the aspect ratio is off?
            seeding_image = transforms_F.resize(
                seeding_image,
                size=(self.H, self.W),  # type: ignore
                interpolation=transforms_F.InterpolationMode.BICUBIC,
                antialias=True,
            )
        # Switch from [B, C, H, W] to [B, C, T, H, W].
        self.seeding_image = seeding_image[:, :, None, ...]

        working_resolutions_b2_np = np.tile([[self.args.width, self.args.height]], (n, 1))
        return (
            estimated_w2c_b44_np,
            estimated_focal_lengths_b2_np,
            estimated_principal_point_rel_b2_np,
            working_resolutions_b2_np,
        )
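
    # Illustrative sketch of single-image seeding (all values assumed). With a
    # single image, depth and camera parameters are estimated by MoGe, so the
    # returned extrinsics/intrinsics may differ from the ones passed in.
    # `resolutions` is currently only shape-checked.
    #
    #   image = np.random.rand(1, 720, 1280, 3).astype(np.float32)  # 0..1 RGB
    #   w2c = np.eye(4, dtype=np.float32)[None]
    #   focals = np.array([[800.0, 800.0]], dtype=np.float32)
    #   pp_rel = np.array([[0.5, 0.5]], dtype=np.float32)
    #   res = np.array([[720, 1280]])
    #   est_w2c, est_focals, est_pp_rel, work_res = model.seed_model_from_values(
    #       image, None, w2c, focals, pp_rel, res)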

    def inference_on_cameras(self, view_cameras_w2cs: np.ndarray, view_camera_intrinsics: np.ndarray,
                             fps: int | float,
                             overlap_frames: int = 1,
                             return_estimated_depths: bool = False,
                             video_save_quality: int = 5,
                             save_buffer: bool | None = None) -> dict | None:
        # TODO: this is not safe if multiple inference requests are served in parallel.
        # TODO: also, it's not 100% clear whether it is correct to override the FPS
        # per request, after the pipeline has been initialized.
        self.pipeline.fps = int(fps)
        del fps
        save_buffer = save_buffer if (save_buffer is not None) else self.args.save_buffer

        video_save_name = self.args.video_save_name
        if not video_save_name:
            video_save_name = f"video_{time.strftime('%Y-%m-%d_%H-%M-%S')}"
        video_save_path = os.path.join(self.args.video_save_folder, f"{video_save_name}.mp4")
        os.makedirs(self.args.video_save_folder, exist_ok=True)

        cache_is_multiframe = isinstance(self.cache, Cache4D)

        # Note: the inference server already adjusted the intrinsics to match our
        # inference resolution (self.W, self.H), so this call just makes sure that
        # all tensors have the right shape, device, etc.
        view_cameras_w2cs, view_camera_intrinsics = self.prepare_camera_for_inference(
            view_cameras_w2cs, view_camera_intrinsics,
            old_size=(self.H, self.W), new_size=(self.H, self.W)
        )
        n_frames_total = view_cameras_w2cs.shape[1]
        num_ar_iterations = (n_frames_total - overlap_frames) // (self.sample_n_frames - overlap_frames)
        log.info(f"Generating {n_frames_total} frames will take {num_ar_iterations} auto-regressive iterations")
        assert self.cache is not None, "The model must be seeded before running inference."

        # Note: the camera trajectory is given by the user, no need to generate it.
        log.info(f"Generating frames 0 - {self.sample_n_frames} (out of {n_frames_total} total)...")
        rendered_warp_images, rendered_warp_masks = self.cache.render_cache(
            view_cameras_w2cs[:, 0:self.sample_n_frames],
            view_camera_intrinsics[:, 0:self.sample_n_frames],
            start_frame_idx=0,
        )

        all_rendered_warps = []
        all_predicted_depth = []
        if save_buffer:
            all_rendered_warps.append(rendered_warp_images.clone().cpu())

        current_prompt = self.args.prompt
        if current_prompt is None and self.args.disable_prompt_upsampler:
            log.critical("Prompt is missing, skipping world generation.")
            return None

        # Generate video
        starting_frame = self.seeding_image
        if cache_is_multiframe:
            starting_frame = starting_frame[0].unsqueeze(0)
        generated_output = self.pipeline.generate(
            prompt=current_prompt,
            image_path=starting_frame,
            negative_prompt=self.args.negative_prompt,
            rendered_warp_images=rendered_warp_images,
            rendered_warp_masks=rendered_warp_masks,
        )
        if generated_output is None:
            log.critical("Guardrail blocked video2world generation.")
            return None
        video, _ = generated_output

        def depth_for_frame(frame: np.ndarray | torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
            last_frame_hwc_0_255 = torch.tensor(frame, device=self.device_with_rank)
            pred_image_for_depth_chw_0_1 = last_frame_hwc_0_255.permute(2, 0, 1) / 255.0  # (C,H,W), range [0,1]
            pred_depth, pred_mask = _predict_moge_depth_from_tensor(
                pred_image_for_depth_chw_0_1, self.moge_model
            )
            return pred_depth, pred_mask, pred_image_for_depth_chw_0_1
        # We predict depth either if we need it (multi-round generation without depth
        # in the cache), or if the user requested it explicitly.
        need_depth_of_latest_frame = return_estimated_depths or (num_ar_iterations > 1 and not cache_is_multiframe)
        if need_depth_of_latest_frame:
            pred_depth, _, pred_image_for_depth_chw_0_1 = depth_for_frame(video[-1])

            if return_estimated_depths:
                # For easier indexing, we include entries even for the frames for which
                # we don't predict depth. Since the results will be transmitted in
                # compressed format, this hopefully shouldn't take up any additional bandwidth.
                depths_batch_0 = np.full((video.shape[0], 1, self.H, self.W), fill_value=np.nan,
                                         dtype=np.float32)
                depths_batch_0[-1, ...] = pred_depth.cpu().numpy()
                all_predicted_depth.append(depths_batch_0)
                del depths_batch_0

        # Autoregressive generation (if needed)
        for num_iter in range(1, num_ar_iterations):
            # Overlap consecutive chunks by `overlap_frames` frames.
            start_frame_idx = num_iter * (self.sample_n_frames - overlap_frames)
            end_frame_idx = start_frame_idx + self.sample_n_frames
            log.info(f"Generating frames {start_frame_idx} - {end_frame_idx} (out of {n_frames_total} total)...")
log.info(f"Generating frames {start_frame_idx} - {end_frame_idx} (out of {n_frames_total} total)...") | |
if cache_is_multiframe: | |
# Nothing much to do, we assume that depth is alraedy provided and | |
# all frames of the seeding video are already in the cache. | |
pred_image_for_depth_chw_0_1 = torch.tensor( | |
video[-1], device=self.device_with_rank | |
).permute(2, 0, 1) / 255.0 # (C,H,W), range [0,1] | |
else: | |
self.cache.update_cache( | |
new_image=pred_image_for_depth_chw_0_1.unsqueeze(0) * 2 - 1, # (B,C,H,W) range [-1,1] | |
new_depth=pred_depth, # (1,1,H,W) | |
# new_mask=pred_mask, # (1,1,H,W) | |
new_w2c=view_cameras_w2cs[:, start_frame_idx], | |
new_intrinsics=view_camera_intrinsics[:, start_frame_idx], | |
) | |
current_segment_w2cs = view_cameras_w2cs[:, start_frame_idx:end_frame_idx] | |
current_segment_intrinsics = view_camera_intrinsics[:, start_frame_idx:end_frame_idx] | |
cache_start_frame_idx = 0 | |
if cache_is_multiframe: | |
# If requesting more frames than are available in the cache, | |
# freeze (hold) on the last batch of frames. | |
cache_start_frame_idx = min( | |
start_frame_idx, | |
self.cache.input_frame_count() - (end_frame_idx - start_frame_idx) | |
) | |
rendered_warp_images, rendered_warp_masks = self.cache.render_cache( | |
current_segment_w2cs, | |
current_segment_intrinsics, | |
start_frame_idx=cache_start_frame_idx, | |
) | |
if save_buffer: | |
all_rendered_warps.append(rendered_warp_images[:, overlap_frames:].clone().cpu()) | |
pred_image_for_depth_bcthw_minus1_1 = pred_image_for_depth_chw_0_1.unsqueeze(0).unsqueeze(2) * 2 - 1 # (B,C,T,H,W), range [-1,1] | |
generated_output = self.pipeline.generate( | |
prompt=current_prompt, | |
image_path=pred_image_for_depth_bcthw_minus1_1, | |
negative_prompt=self.args.negative_prompt, | |
rendered_warp_images=rendered_warp_images, | |
rendered_warp_masks=rendered_warp_masks, | |
) | |
video_new, _ = generated_output | |
video = np.concatenate([video, video_new[overlap_frames:]], axis=0) | |
            # Prepare the depth prediction for the next AR iteration.
            need_depth_of_latest_frame = return_estimated_depths or ((num_iter < num_ar_iterations - 1) and not cache_is_multiframe)
            if need_depth_of_latest_frame:
                # Either we don't have depth (e.g. single-image seeding), or the user
                # requested depth to be returned explicitly.
                pred_depth, _, pred_image_for_depth_chw_0_1 = depth_for_frame(video_new[-1])

                if return_estimated_depths:
                    depths_batch_i = np.full((video_new.shape[0] - overlap_frames, 1, self.H, self.W),
                                             fill_value=np.nan, dtype=np.float32)
                    depths_batch_i[-1, ...] = pred_depth.cpu().numpy()
                    all_predicted_depth.append(depths_batch_i)
                    del depths_batch_i
        if is_rank0():
            # Final video processing
            final_video_to_save = video
            final_width = self.args.width

            if save_buffer and all_rendered_warps:
                # Each element is (T_chunk, n_i, C, H, W).
                squeezed_warps = [t.squeeze(0) for t in all_rendered_warps]
                n_max = max(t.shape[1] for t in squeezed_warps)

                # Pad each chunk along the buffer dimension so that all chunks have
                # n_max warp buffers, then concatenate along time.
                padded_t_list = []
                for sq_t in squeezed_warps:
                    # sq_t shape: (T_chunk, n_i, C, H, W)
                    padding_needed_dim1 = n_max - sq_t.shape[1]
                    # F.pad takes (before, after) pairs from the last dimension backwards.
                    pad_spec = (0, 0,                    # W
                                0, 0,                    # H
                                0, 0,                    # C
                                0, padding_needed_dim1,  # n_i
                                0, 0)                    # T_chunk
                    padded_t_list.append(F.pad(sq_t, pad_spec, mode="constant", value=-1.0))

                full_rendered_warp_tensor = torch.cat(padded_t_list, dim=0)
                T_total, _, C_dim, H_dim, W_dim = full_rendered_warp_tensor.shape
                # Lay the n_max warp buffers of each frame side by side along width.
                buffer_video_TCHnW = full_rendered_warp_tensor.permute(0, 2, 3, 1, 4)
                buffer_video_TCHWstacked = buffer_video_TCHnW.contiguous().view(T_total, C_dim, H_dim, n_max * W_dim)
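                # E.g. with n_max = 2, the two warp buffers of each frame are laid
                # out side by side, producing a strip of width 2 * W_dim.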
                # Map from [-1, 1] back to [0, 255] for saving.
                buffer_video_TCHWstacked = (buffer_video_TCHWstacked * 0.5 + 0.5) * 255.0
                buffer_numpy_TCHWstacked = buffer_video_TCHWstacked.cpu().numpy().astype(np.uint8)
                buffer_numpy_THWC = np.transpose(buffer_numpy_TCHWstacked, (0, 2, 3, 1))
                # Place the warp buffers to the left of the generated video.
                final_video_to_save = np.concatenate([buffer_numpy_THWC, final_video_to_save], axis=2)
                final_width = self.args.width * (1 + n_max)
                log.info(f"Concatenating video with {n_max} warp buffers. Final video width will be {final_width}")
            elif save_buffer:
                log.info("No warp buffers to save.")

            # Save video
            save_video(
                video=final_video_to_save,
                fps=self.pipeline.fps,
                H=self.args.height,
                W=final_width,
                video_save_quality=video_save_quality,
                video_save_path=video_save_path,
            )
            log.info(f"Saved video to {video_save_path}")
        if return_estimated_depths:
            predicted_depth = np.concatenate(all_predicted_depth, axis=0)
        else:
            predicted_depth = None

        # Currently, `video` is [n_frames, height, width, channels].
        # Return it as [1, n_frames, channels, height, width] for consistency with other codebases.
        video = video.transpose(0, 3, 1, 2)[None, ...]
        # Depth is returned as [n_frames, channels, height, width].

        # TODO: handle overlap
        rendered_warp_images_no_overlap = rendered_warp_images
        video_no_overlap = video

        return {
            "rendered_warp_images": rendered_warp_images,
            "video": video,
            "rendered_warp_images_no_overlap": rendered_warp_images_no_overlap,
            "video_no_overlap": video_no_overlap,
            "predicted_depth": predicted_depth,
            "video_save_path": video_save_path,
        }
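
    # Illustrative sketch of an inference request (assumed values; real requests
    # come from the inference server, and `K_resized` reuses the hypothetical
    # intrinsics example above):
    #
    #   n_frames = 121
    #   w2cs = np.tile(np.eye(4, dtype=np.float32), (n_frames, 1, 1))
    #   Ks = np.tile(K_resized.astype(np.float32), (n_frames, 1, 1))
    #   result = model.inference_on_cameras(w2cs, Ks, fps=24)
    #   if result is not None:
    #       print(result["video"].shape)  # [1, n_frames, channels, height, width]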
    # --------------------

    def prepare_camera_for_inference(self, view_cameras: np.ndarray, view_camera_intrinsics: np.ndarray,
                                     old_size: tuple[int, int], new_size: tuple[int, int]):
        """Old and new sizes should be given as (height, width)."""
        if isinstance(view_cameras, np.ndarray):
            view_cameras = torch.from_numpy(view_cameras).float().contiguous()
        if view_cameras.ndim == 3:
            view_cameras = view_cameras.unsqueeze(dim=0)

        if isinstance(view_camera_intrinsics, np.ndarray):
            view_camera_intrinsics = torch.from_numpy(view_camera_intrinsics).float().contiguous()
        view_camera_intrinsics = resize_intrinsics(view_camera_intrinsics, old_size, new_size)
        view_camera_intrinsics = view_camera_intrinsics.unsqueeze(dim=0)
        assert view_camera_intrinsics.ndim == 4

        return view_cameras.to(self.device_with_rank), \
               view_camera_intrinsics.to(self.device_with_rank)

    def get_cache_input_depths(self) -> torch.Tensor | None:
        if self.cache is None:
            return None
        return self.cache.input_depth

    @property
    def W(self) -> int:
        return self.args.width

    @property
    def H(self) -> int:
        return self.args.height

    def clear_cache(self) -> None:
        self.cache = None
        self.model_was_seeded = False

    def cleanup(self) -> None:
        if self.args.num_gpus > 1:
            rank = get_rank()
            log.info(f"Model cleanup: destroying model parallel group on rank={rank}.",
                     rank0_only=False)
            from megatron.core import parallel_state
            parallel_state.destroy_model_parallel()
            import torch.distributed as dist
            dist.destroy_process_group()
            log.info(f"Destroyed model parallel group on rank={rank}.", rank0_only=False)
        else:
            log.info("Model cleanup: nothing to do (no parallelism).", rank0_only=False)