Spaces: Running on Zero
| import tempfile | |
| from contextlib import contextmanager | |
| from typing import Iterator, Optional, Union | |
| import blobfile as bf | |
| import numpy as np | |
| import torch | |
| from PIL import Image | |
| from shap_e.rendering.blender.render import render_mesh, render_model | |
| from shap_e.rendering.blender.view_data import BlenderViewData | |
| from shap_e.rendering.mesh import TriMesh | |
| from shap_e.rendering.point_cloud import PointCloud | |
| from shap_e.rendering.view_data import ViewData | |
| from shap_e.util.collections import AttrDict | |
| from shap_e.util.image_util import center_crop, get_alpha, remove_alpha, resize | |
def load_or_create_multimodal_batch(
    device: torch.device,
    *,
    mesh_path: Optional[str] = None,
    model_path: Optional[str] = None,
    cache_dir: Optional[str] = None,
    point_count: int = 2**14,
    random_sample_count: int = 2**19,
    pc_num_views: int = 40,
    mv_light_mode: Optional[str] = None,
    mv_num_views: int = 20,
    mv_image_size: int = 512,
    mv_alpha_removal: str = "black",
    verbose: bool = False,
) -> AttrDict:
    """
    Build an encoder input batch from a 3D asset: a colored point cloud and,
    when mv_light_mode is set, multiview renders with cameras, alphas, and
    (if available) depths.

    Exactly one of mesh_path / model_path must be given (enforced by the
    helpers this calls). Intermediate results are cached under cache_dir
    when it is provided.

    :param device: torch device the point tensor is placed on.
    :param point_count: number of points in the final point cloud.
    :param random_sample_count: intermediate uniform subsample size used
        before farthest-point sampling.
    :param pc_num_views: number of rendered views used to build the point cloud.
    :param mv_light_mode: Blender light mode for the multiview renders; if
        None, no multiview data is added to the batch.
    :param mv_num_views: number of views in the multiview batch.
    :param mv_image_size: square side length views/depths are resized to.
    :param mv_alpha_removal: background-fill mode passed to process_image.
    :return: an AttrDict with "points" and, optionally, "depths", "views",
        "view_alphas", and "cameras", normalized for the encoder.
    """
    if verbose:
        print("creating point cloud...")
    pc = load_or_create_pc(
        mesh_path=mesh_path,
        model_path=model_path,
        cache_dir=cache_dir,
        random_sample_count=random_sample_count,
        point_count=point_count,
        num_views=pc_num_views,
        verbose=verbose,
    )
    # [N, 6] array: XYZ coordinates followed by RGB channel values.
    raw_pc = np.concatenate([pc.coords, pc.select_channels(["R", "G", "B"])], axis=-1)
    encode_me = torch.from_numpy(raw_pc).float().to(device)
    # Stored channels-first with a leading batch dim of 1: shape [1, 6, N].
    batch = AttrDict(points=encode_me.t()[None])
    if mv_light_mode:
        if verbose:
            print("creating multiview...")
        with load_or_create_multiview(
            mesh_path=mesh_path,
            model_path=model_path,
            cache_dir=cache_dir,
            num_views=mv_num_views,
            extract_material=False,
            light_mode=mv_light_mode,
            verbose=verbose,
        ) as mv:
            cameras, views, view_alphas, depths = [], [], [], []
            for view_idx in range(mv.num_views):
                # Request the alpha channel only when the renders contain one.
                camera, view = mv.load_view(
                    view_idx,
                    ["R", "G", "B", "A"] if "A" in mv.channel_names else ["R", "G", "B"],
                )
                depth = None
                if "D" in mv.channel_names:
                    _, depth = mv.load_view(view_idx, ["D"])
                    depth = process_depth(depth, mv_image_size)
                # Convert the float view to uint8 pixels (the *255 round
                # suggests values in [0, 1] — TODO confirm against ViewData).
                view, alpha = process_image(
                    np.round(view * 255.0).astype(np.uint8), mv_alpha_removal, mv_image_size
                )
                # Keep the camera consistent with the crop/resize applied to the images.
                camera = camera.center_crop().resize_image(mv_image_size, mv_image_size)
                cameras.append(camera)
                views.append(view)
                view_alphas.append(alpha)
                depths.append(depth)
            # Each field is wrapped in an extra list: one entry per batch
            # element, and this batch has size 1.
            batch.depths = [depths]
            batch.views = [views]
            batch.view_alphas = [view_alphas]
            batch.cameras = [cameras]
    return normalize_input_batch(batch, pc_scale=2.0, color_scale=1.0 / 255.0)
def load_or_create_pc(
    *,
    mesh_path: Optional[str],
    model_path: Optional[str],
    cache_dir: Optional[str],
    random_sample_count: int,
    point_count: int,
    num_views: int,
    verbose: bool = False,
) -> PointCloud:
    """
    Return a colored point cloud for the asset, reading from (and writing to)
    an .npz cache under cache_dir when one is provided.

    Exactly one of mesh_path / model_path must be specified. On a cache miss,
    the asset is rendered to a multiview and the cloud extracted from it.
    """
    assert (model_path is not None) ^ (
        mesh_path is not None
    ), "must specify exactly one of model_path or mesh_path"
    path = mesh_path if mesh_path is not None else model_path

    cache_path = None
    if cache_dir is not None:
        # Cache key encodes every parameter that affects the result.
        cache_name = f"pc_{bf.basename(path)}_mat_{num_views}_{random_sample_count}_{point_count}.npz"
        cache_path = bf.join(cache_dir, cache_name)
        if bf.exists(cache_path):
            # Cache hit: skip rendering entirely.
            return PointCloud.load(cache_path)

    with load_or_create_multiview(
        mesh_path=mesh_path,
        model_path=model_path,
        cache_dir=cache_dir,
        num_views=num_views,
        verbose=verbose,
    ) as mv:
        if verbose:
            print("extracting point cloud from multiview...")
        pc = mv_to_pc(
            multiview=mv, random_sample_count=random_sample_count, point_count=point_count
        )

    if cache_path is not None:
        pc.save(cache_path)
    return pc
@contextmanager
def load_or_create_multiview(
    *,
    mesh_path: Optional[str],
    model_path: Optional[str],
    cache_dir: Optional[str],
    num_views: int = 20,
    extract_material: bool = True,
    light_mode: Optional[str] = None,
    verbose: bool = False,
) -> Iterator[BlenderViewData]:
    """
    Context manager yielding a BlenderViewData of rendered views for a mesh
    or model file, cached as a zip under cache_dir when one is provided.

    Exactly one of mesh_path / model_path must be given. When
    extract_material is True, light_mode must be None (material channels
    are rendered); otherwise light_mode selects the Blender lighting setup.

    Bug fix: the @contextmanager decorator was missing, so the `with
    load_or_create_multiview(...) as mv:` call sites would fail on a bare
    generator — contextlib is already imported at the top of this file.
    """
    assert (model_path is not None) ^ (
        mesh_path is not None
    ), "must specify exactly one of model_path or mesh_path"
    path = model_path if model_path is not None else mesh_path
    if extract_material:
        assert light_mode is None, "light_mode is ignored when extract_material=True"
    else:
        assert light_mode is not None, "must specify light_mode when extract_material=False"
    if cache_dir is not None:
        # Material renders and lit renders are cached under distinct names.
        if extract_material:
            cache_path = bf.join(cache_dir, f"mv_{bf.basename(path)}_mat_{num_views}.zip")
        else:
            cache_path = bf.join(cache_dir, f"mv_{bf.basename(path)}_{light_mode}_{num_views}.zip")
        if bf.exists(cache_path):
            # Cache hit: yield directly from the cached zip.
            with bf.BlobFile(cache_path, "rb") as f:
                yield BlenderViewData(f)
            return
    else:
        cache_path = None
    common_kwargs = dict(
        fast_mode=True,
        extract_material=extract_material,
        camera_pose="random",
        # light_mode is None exactly when extract_material=True; Blender
        # still needs some lighting value, so default to "uniform".
        light_mode=light_mode or "uniform",
        verbose=verbose,
    )
    with tempfile.TemporaryDirectory() as tmp_dir:
        tmp_path = bf.join(tmp_dir, "out.zip")
        if mesh_path is not None:
            mesh = TriMesh.load(mesh_path)
            render_mesh(
                mesh=mesh,
                output_path=tmp_path,
                num_images=num_views,
                backend="BLENDER_EEVEE",
                **common_kwargs,
            )
        elif model_path is not None:
            render_model(
                model_path,
                output_path=tmp_path,
                num_images=num_views,
                backend="BLENDER_EEVEE",
                **common_kwargs,
            )
        if cache_path is not None:
            # Populate the cache before handing the data to the caller.
            bf.copy(tmp_path, cache_path)
        with bf.BlobFile(tmp_path, "rb") as f:
            yield BlenderViewData(f)
def mv_to_pc(multiview: ViewData, random_sample_count: int, point_count: int) -> PointCloud:
    """
    Convert multiview RGB-D data into a fixed-size colored point cloud.

    :param multiview: rendered view data to back-project into points.
    :param random_sample_count: size of the uniform subsample taken before
        the (more expensive) farthest-point sampling.
    :param point_count: exact number of points in the returned cloud.
    """
    pc = PointCloud.from_rgbd(multiview)
    # Handle empty samples.
    if len(pc.coords) == 0:
        pc = PointCloud(
            coords=np.zeros([1, 3]),
            channels=dict(zip("RGB", np.zeros([3, 1]))),
        )
    # Double the cloud until it has at least point_count points.
    while len(pc.coords) < point_count:
        pc = pc.combine(pc)
        # Prevent duplicate points; some models may not like it.
        pc.coords += np.random.normal(size=pc.coords.shape) * 1e-4
    pc = pc.random_sample(random_sample_count)
    pc = pc.farthest_point_sample(point_count, average_neighbors=True)
    return pc
| def normalize_input_batch(batch: AttrDict, *, pc_scale: float, color_scale: float) -> AttrDict: | |
| res = batch.copy() | |
| scale_vec = torch.tensor([*([pc_scale] * 3), *([color_scale] * 3)], device=batch.points.device) | |
| res.points = res.points * scale_vec[:, None] | |
| if "cameras" in res: | |
| res.cameras = [[cam.scale_scene(pc_scale) for cam in cams] for cams in res.cameras] | |
| if "depths" in res: | |
| res.depths = [[depth * pc_scale for depth in depths] for depths in res.depths] | |
| return res | |
def process_depth(depth_img: np.ndarray, image_size: int) -> np.ndarray:
    """Center-crop a depth map, resize it to image_size x image_size, and drop singleton dims."""
    cropped = center_crop(depth_img)
    resized = resize(cropped, width=image_size, height=image_size)
    return np.squeeze(resized)
def process_image(
    img_or_img_arr: Union[Image.Image, np.ndarray], alpha_removal: str, image_size: int
):
    """
    Prepare a view image for the batch: center-crop, split off the alpha
    channel, fill the background per alpha_removal, and resize both the RGB
    image and the alpha mask to image_size x image_size.

    :param img_or_img_arr: a PIL image or a numpy array (HxW grayscale arrays
        are only promoted to RGB on the PIL-image path, matching the original
        behavior).
    :param alpha_removal: background-fill mode passed to remove_alpha.
    :return: (rgb_image, alpha_image) as PIL images.

    Fix: removed dead stores — img_arr was assigned in both branches (and
    re-assigned after the grayscale promotion) but only ever read for the
    grayscale shape check.
    """
    if isinstance(img_or_img_arr, np.ndarray):
        img = Image.fromarray(img_or_img_arr)
    else:
        img = img_or_img_arr
        if len(np.array(img).shape) == 2:
            # Grayscale: promote to a 3-channel RGB image.
            rgb = Image.new("RGB", img.size)
            rgb.paste(img)
            img = rgb

    img = center_crop(img)
    alpha = get_alpha(img)
    img = remove_alpha(img, mode=alpha_removal)
    alpha = alpha.resize((image_size,) * 2, resample=Image.BILINEAR)
    img = img.resize((image_size,) * 2, resample=Image.BILINEAR)
    return img, alpha