|
import os |
|
os.environ['OPENCV_IO_ENABLE_OPENEXR'] = '1' |
|
from pathlib import Path |
|
import sys |
|
if (_package_root := str(Path(__file__).absolute().parents[2])) not in sys.path: |
|
sys.path.insert(0, _package_root) |
|
from typing import * |
|
import itertools |
|
import json |
|
import warnings |
|
|
|
import click |
|
|
|
|
|
@click.command(help='Inference script for panorama images') |
|
@click.option('--input', '-i', 'input_path', type=click.Path(exists=True), required=True, help='Input image or folder path. "jpg" and "png" are supported.') |
|
@click.option('--output', '-o', 'output_path', type=click.Path(), default='./output', help='Output folder path') |
|
@click.option('--pretrained', 'pretrained_model_name_or_path', type=str, default='Ruicheng/moge-vitl', help='Pretrained model name or path. Defaults to "Ruicheng/moge-vitl"') |
|
@click.option('--device', 'device_name', type=str, default='cuda', help='Device name (e.g. "cuda", "cuda:0", "cpu"). Defaults to "cuda"') |
|
@click.option('--resize', 'resize_to', type=int, default=None, help='Resize the image(s) & output maps to a specific size. Defaults to None (no resizing).') |
|
@click.option('--resolution_level', type=int, default=9, help='An integer [0-9] for the resolution level of inference. The higher, the better but slower. Defaults to 9. Note that it is irrelevant to the output resolution.') |
|
@click.option('--threshold', type=float, default=0.03, help='Threshold for removing edges. Defaults to 0.03. Smaller value removes more edges. "inf" means no thresholding.') |
|
@click.option('--batch_size', type=int, default=4, help='Batch size for inference. Defaults to 4.') |
|
@click.option('--splitted', 'save_splitted', is_flag=True, help='Whether to save the splitted images. Defaults to False.') |
|
@click.option('--maps', 'save_maps_', is_flag=True, help='Whether to save the output maps and fov(image, depth, mask, points, fov).') |
|
@click.option('--glb', 'save_glb_', is_flag=True, help='Whether to save the output as a.glb file. The color will be saved as a texture.') |
|
@click.option('--ply', 'save_ply_', is_flag=True, help='Whether to save the output as a.ply file. The color will be saved as vertex colors.') |
|
@click.option('--show', 'show', is_flag=True, help='Whether show the output in a window. Note that this requires pyglet<2 installed as required by trimesh.') |
|
def main( |
|
input_path: str, |
|
output_path: str, |
|
pretrained_model_name_or_path: str, |
|
device_name: str, |
|
resize_to: int, |
|
resolution_level: int, |
|
threshold: float, |
|
batch_size: int, |
|
save_splitted: bool, |
|
save_maps_: bool, |
|
save_glb_: bool, |
|
save_ply_: bool, |
|
show: bool, |
|
): |
|
|
|
import cv2 |
|
import numpy as np |
|
from numpy import ndarray |
|
import torch |
|
from PIL import Image |
|
from tqdm import tqdm, trange |
|
import trimesh |
|
import trimesh.visual |
|
from scipy.sparse import csr_array, hstack, vstack |
|
from scipy.ndimage import convolve |
|
from scipy.sparse.linalg import lsmr |
|
|
|
import utils3d |
|
from moge.model.v1 import MoGeModel |
|
from moge.utils.io import save_glb, save_ply |
|
from moge.utils.vis import colorize_depth |
|
from moge.utils.panorama import spherical_uv_to_directions, get_panorama_cameras, split_panorama_image, merge_panorama_depth |
|
|
|
|
|
device = torch.device(device_name) |
|
|
|
include_suffices = ['jpg', 'png', 'jpeg', 'JPG', 'PNG', 'JPEG'] |
|
if Path(input_path).is_dir(): |
|
image_paths = sorted(itertools.chain(*(Path(input_path).rglob(f'*.{suffix}') for suffix in include_suffices))) |
|
else: |
|
image_paths = [Path(input_path)] |
|
|
|
if len(image_paths) == 0: |
|
raise FileNotFoundError(f'No image files found in {input_path}') |
|
|
|
|
|
if not any([save_maps_, save_glb_, save_ply_]): |
|
warnings.warn('No output format specified. Defaults to saving all. Please use "--maps", "--glb", or "--ply" to specify the output.') |
|
save_maps_ = save_glb_ = save_ply_ = True |
|
|
|
model = MoGeModel.from_pretrained(pretrained_model_name_or_path).to(device).eval() |
|
|
|
for image_path in (pbar := tqdm(image_paths, desc='Total images', disable=len(image_paths) <= 1)): |
|
image = cv2.cvtColor(cv2.imread(str(image_path)), cv2.COLOR_BGR2RGB) |
|
height, width = image.shape[:2] |
|
if resize_to is not None: |
|
height, width = min(resize_to, int(resize_to * height / width)), min(resize_to, int(resize_to * width / height)) |
|
image = cv2.resize(image, (width, height), cv2.INTER_AREA) |
|
|
|
splitted_extrinsics, splitted_intriniscs = get_panorama_cameras() |
|
splitted_resolution = 512 |
|
splitted_images = split_panorama_image(image, splitted_extrinsics, splitted_intriniscs, splitted_resolution) |
|
|
|
|
|
print('Inferring...') if pbar.disable else pbar.set_postfix_str(f'Inferring') |
|
|
|
splitted_distance_maps, splitted_masks = [], [] |
|
for i in trange(0, len(splitted_images), batch_size, desc='Inferring splitted views', disable=len(splitted_images) <= batch_size, leave=False): |
|
image_tensor = torch.tensor(np.stack(splitted_images[i:i + batch_size]) / 255, dtype=torch.float32, device=device).permute(0, 3, 1, 2) |
|
fov_x, fov_y = np.rad2deg(utils3d.numpy.intrinsics_to_fov(np.array(splitted_intriniscs[i:i + batch_size]))) |
|
fov_x = torch.tensor(fov_x, dtype=torch.float32, device=device) |
|
output = model.infer(image_tensor, fov_x=fov_x, apply_mask=False) |
|
distance_map, mask = output['points'].norm(dim=-1).cpu().numpy(), output['mask'].cpu().numpy() |
|
splitted_distance_maps.extend(list(distance_map)) |
|
splitted_masks.extend(list(mask)) |
|
|
|
|
|
if save_splitted: |
|
splitted_save_path = Path(output_path, image_path.stem, 'splitted') |
|
splitted_save_path.mkdir(exist_ok=True, parents=True) |
|
for i in range(len(splitted_images)): |
|
cv2.imwrite(str(splitted_save_path / f'{i:02d}.jpg'), cv2.cvtColor(splitted_images[i], cv2.COLOR_RGB2BGR)) |
|
cv2.imwrite(str(splitted_save_path / f'{i:02d}_distance_vis.png'), cv2.cvtColor(colorize_depth(splitted_distance_maps[i], splitted_masks[i]), cv2.COLOR_RGB2BGR)) |
|
|
|
|
|
print('Merging...') if pbar.disable else pbar.set_postfix_str(f'Merging') |
|
|
|
merging_width, merging_height = min(1920, width), min(960, height) |
|
panorama_depth, panorama_mask = merge_panorama_depth(merging_width, merging_height, splitted_distance_maps, splitted_masks, splitted_extrinsics, splitted_intriniscs) |
|
panorama_depth = panorama_depth.astype(np.float32) |
|
panorama_depth = cv2.resize(panorama_depth, (width, height), cv2.INTER_LINEAR) |
|
panorama_mask = cv2.resize(panorama_mask.astype(np.uint8), (width, height), cv2.INTER_NEAREST) > 0 |
|
points = panorama_depth[:, :, None] * spherical_uv_to_directions(utils3d.numpy.image_uv(width=width, height=height)) |
|
|
|
|
|
print('Writing outputs...') if pbar.disable else pbar.set_postfix_str(f'Inferring') |
|
save_path = Path(output_path, image_path.relative_to(input_path).parent, image_path.stem) |
|
save_path.mkdir(exist_ok=True, parents=True) |
|
if save_maps_: |
|
cv2.imwrite(str(save_path / 'image.jpg'), cv2.cvtColor(image, cv2.COLOR_RGB2BGR)) |
|
cv2.imwrite(str(save_path / 'depth_vis.png'), cv2.cvtColor(colorize_depth(panorama_depth, mask=panorama_mask), cv2.COLOR_RGB2BGR)) |
|
cv2.imwrite(str(save_path / 'depth.exr'), panorama_depth, [cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_FLOAT]) |
|
cv2.imwrite(str(save_path / 'points.exr'), points, [cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_FLOAT]) |
|
cv2.imwrite(str(save_path /'mask.png'), (panorama_mask * 255).astype(np.uint8)) |
|
|
|
|
|
if save_glb_ or save_ply_ or show: |
|
normals, normals_mask = utils3d.numpy.points_to_normals(points, panorama_mask) |
|
faces, vertices, vertex_colors, vertex_uvs = utils3d.numpy.image_mesh( |
|
points, |
|
image.astype(np.float32) / 255, |
|
utils3d.numpy.image_uv(width=width, height=height), |
|
mask=panorama_mask & ~(utils3d.numpy.depth_edge(panorama_depth, rtol=threshold) & utils3d.numpy.normals_edge(normals, tol=5, mask=normals_mask)), |
|
tri=True |
|
) |
|
|
|
if save_glb_: |
|
save_glb(save_path / 'mesh.glb', vertices, faces, vertex_uvs, image) |
|
|
|
if save_ply_: |
|
save_ply(save_path / 'mesh.ply', vertices, faces, vertex_colors) |
|
|
|
if show: |
|
trimesh.Trimesh( |
|
vertices=vertices, |
|
vertex_colors=vertex_colors, |
|
faces=faces, |
|
process=False |
|
).show() |
|
|
|
|
|
if __name__ == '__main__': |
|
main() |