import os
from os.path import basename, splitext, join
import tempfile
import gradio as gr
import numpy as np
from PIL import Image
import torch
import cv2
from torchvision.transforms.functional import to_tensor, to_pil_image
from torch import Tensor
from genstereo import GenStereo, AdaptiveFusionLayer
import ssl
from huggingface_hub import hf_hub_download
import spaces

from extern.DAM2.depth_anything_v2.dpt import DepthAnythingV2
ssl._create_default_https_context = ssl._create_unverified_context

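# Working resolution, compute device, and the checkpoint subdirectory holding the GenStereo weights.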
IMAGE_SIZE = 512
DEVICE = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
CHECKPOINT_NAME = 'genstereo'

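# Fetch all required model weights from the Hugging Face Hub into ./checkpoints.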
def download_models():
    models = [
        {
            'repo': 'stabilityai/sd-vae-ft-mse',
            'sub': None,
            'dst': 'checkpoints/sd-vae-ft-mse',
            'files': ['config.json', 'diffusion_pytorch_model.safetensors'],
            'token': None
        },
        {
            'repo': 'lambdalabs/sd-image-variations-diffusers',
            'sub': 'image_encoder',
            'dst': 'checkpoints',
            'files': ['config.json', 'pytorch_model.bin'],
            'token': None
        },
        {
            'repo': 'FQiao/GenStereo',
            'sub': None,
            'dst': 'checkpoints/genstereo',
            'files': ['config.json', 'denoising_unet.pth', 'fusion_layer.pth', 'pose_guider.pth', 'reference_unet.pth'],
            'token': None
        },
        {
            'repo': 'depth-anything/Depth-Anything-V2-Large',
            'sub': None,
            'dst': 'checkpoints',
            'files': ['depth_anything_v2_vitl.pth'],
            'token': None
        }
    ]

    for model in models:
        for file in model['files']:
            hf_hub_download(
                repo_id=model['repo'],
                subfolder=model['sub'],
                filename=file,
                local_dir=model['dst'],
                token=model['token']
            )

download_models()

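# Build the Depth Anything V2 monocular depth estimator (ViT-L) and load its checkpoint.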
def get_dam2_model():
    model_configs = {
        'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]},
        'vitb': {'encoder': 'vitb', 'features': 128, 'out_channels': [96, 192, 384, 768]},
        'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]},
    }

    encoder = 'vitl'
    encoder_size_map = {'vits': 'Small', 'vitb': 'Base', 'vitl': 'Large'}

    if encoder not in encoder_size_map:
        raise ValueError(f"Unsupported encoder: {encoder}. Supported: {list(encoder_size_map.keys())}")

    dam2 = DepthAnythingV2(**model_configs[encoder])
    dam2_checkpoint = f'checkpoints/depth_anything_v2_{encoder}.pth'
    dam2.load_state_dict(torch.load(dam2_checkpoint, map_location='cpu'))
    dam2 = dam2.to(DEVICE).eval()
    return dam2

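# Instantiate the GenStereo pipeline from the downloaded checkpoints (half-precision weights).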
def get_genstereo_model():
    genwarp_cfg = dict(
        pretrained_model_path='checkpoints',
        checkpoint_name=CHECKPOINT_NAME,
        half_precision_weights=True
    )
    genstereo = GenStereo(cfg=genwarp_cfg, device=DEVICE)
    return genstereo

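# Load the adaptive fusion layer that blends the warped image with the synthesized image.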
def get_fusion_model():
    fusion_model = AdaptiveFusionLayer()
    fusion_checkpoint = join('checkpoints', CHECKPOINT_NAME, 'fusion_layer.pth')
    fusion_model.load_state_dict(torch.load(fusion_checkpoint, map_location='cpu'))
    fusion_model = fusion_model.to(DEVICE).eval()
    return fusion_model

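# Center-crop an image to a square whose side equals the shorter image dimension.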
def crop(img: Image.Image) -> Image.Image:
    W, H = img.size
    if W < H:
        left, right = 0, W
        top, bottom = np.ceil((H - W) / 2.), np.floor((H - W) / 2.) + W
    else:
        left, right = np.ceil((W - H) / 2.), np.floor((W - H) / 2.) + H
        top, bottom = 0, H
    return img.crop((left, top, right, bottom))

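# Gradio UI: estimate depth for the uploaded left image, then synthesize the right view.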
with tempfile.TemporaryDirectory() as tmpdir:
    with gr.Blocks(
        title='StereoGen Demo',
        css='img {display: inline;}'
    ) as demo:

        src_image = gr.State()
        src_depth = gr.State()

        def normalize_disp(disp):
            return (disp - disp.min()) / (disp.max() - disp.min())

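        # Depth-estimation callback: crop and resize the upload, run Depth Anything V2,
        # and return the processed image plus a colorized disparity map for display.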
        @spaces.GPU()
        def cb_mde(image_file: str):
            if not image_file:
                return None, None, None, None

            image = crop(Image.open(image_file).convert('RGB'))
            image = image.resize((IMAGE_SIZE, IMAGE_SIZE))

            image_bgr = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)

            dam2 = get_dam2_model()
            depth_dam2 = dam2.infer_image(image_bgr)
            depth = torch.tensor(depth_dam2).unsqueeze(0).unsqueeze(0).float()

            # applyColorMap returns BGR; convert to RGB so Gradio displays the colors correctly.
            depth_image = cv2.applyColorMap((normalize_disp(depth_dam2) * 255).astype(np.uint8), cv2.COLORMAP_JET)
            depth_image = cv2.cvtColor(depth_image, cv2.COLOR_BGR2RGB)

            return image, depth_image, image, depth

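        # Right-view synthesis callback: scale the normalized disparity by the slider value,
        # warp the left image with GenStereo, and fuse the warped and synthesized outputs.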
        @spaces.GPU()
        def cb_generate(image, depth: Tensor, scale_factor):
            if image is None or depth is None:
                return None, None

            norm_disp = normalize_disp(depth.to(DEVICE))
            disp = norm_disp * scale_factor / 100 * IMAGE_SIZE

            genstereo = get_genstereo_model()
            fusion_model = get_fusion_model()

            renders = genstereo(
                src_image=image,
                src_disparity=disp,
                ratio=None,
            )
            warped = (renders['warped'] + 1) / 2

            synthesized = renders['synthesized']
            mask = renders['mask']
            fusion_image = fusion_model(synthesized.float(), warped.float(), mask.float())

            warped_pil = to_pil_image(warped[0])
            fusion_pil = to_pil_image(fusion_image[0])

            return warped_pil, fusion_pil

        gr.Markdown(
            """
# GenWarp: Single Image to Novel Views with Semantic-Preserving Generative Warping

[Project Site](https://genwarp-nvs.github.io/)
[Spaces Demo](https://huggingface.co/spaces/Sony/GenWarp)
[Code](https://github.com/sony/genwarp/)
[Model](https://huggingface.co/Sony/genwarp)
[Paper](https://arxiv.org/abs/2405.17251)

## Introduction
This is an official demo for the paper "[GenWarp: Single Image to Novel Views with Semantic-Preserving Generative Warping](https://genwarp-nvs.github.io/)". GenWarp can generate novel-view images from a single input image conditioned on camera poses. In this demo, we offer basic inference with the model. For detailed information, please refer to the [paper](https://arxiv.org/abs/2405.17251).

## How to Use

### Try examples
- Examples are in the bottom section of the page

### Upload your own images
1. Upload an image to "Left"
2. Adjust the "Scale Factor" slider to set the stereo disparity
3. Hit the "Generate a right image" button and check the "Warped Image" and "Generated Right" results

## Tips
- This model is mainly trained for indoor/outdoor scenery. It might not work well for object-centric inputs. For details on training the model, please check our [paper](https://arxiv.org/abs/2405.17251).
- Extremely large camera movement from the input view might produce poor results because it deviates from the training distribution, which is outside the scope of this model. Instead, you can feed the generated result back in with small camera movements and progressively move toward the desired view.
- The 3D viewer might take some time to update, especially when trying different images back to back. Wait until it fully updates to the new image.
            """
        )
        file = gr.File(label='Left', file_types=['image'])
        examples = gr.Examples(
            examples=['./assets/COCO_val2017_000000070229.jpg',
                      './assets/COCO_val2017_000000092839.jpg',
                      './assets/KITTI2015_000003_10.png',
                      './assets/KITTI2015_000147_10.png'],
            inputs=file
        )
        with gr.Row():
            image_widget = gr.Image(
                label='Left Image', type='pil',
                interactive=False
            )
            depth_widget = gr.Image(label='Estimated Depth', type='pil')

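        # Disparity scale used by cb_generate: the slider value is applied as a percentage of the image width.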
        scale_slider = gr.Slider(
            label='Scale Factor',
            minimum=1.0,
            maximum=30.0,
            value=15.0,
            step=0.1,
        )

        button = gr.Button('Generate a right image', size='lg', variant='primary')
        with gr.Row():
            warped_widget = gr.Image(
                label='Warped Image', type='pil', interactive=False
            )
            gen_widget = gr.Image(
                label='Generated Right', type='pil', interactive=False
            )

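        # Wire up events: depth estimation on upload, stereo generation on button click.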
        file.change(
            fn=cb_mde,
            inputs=file,
            outputs=[image_widget, depth_widget, src_image, src_depth]
        )
        button.click(
            fn=cb_generate,
            inputs=[src_image, src_depth, scale_slider],
            outputs=[warped_widget, gen_widget]
        )

    if __name__ == '__main__':
        demo.launch()