# Hugging Face Space file "app.py" by jadechoghari (HF Staff),
# added in verified commit 7cbe23c ("Create app.py"); 3.49 kB.
import os
from functools import lru_cache, partial

import torch
import gradio as gr
import numpy as np
import trimesh
import mcubes
from torchvision.utils import save_image
from PIL import Image
from transformers import AutoModel, AutoConfig
from rembg import remove, new_session
from kiui.op import recenter
import kiui
# Wrapper that loads the pre-trained model from the Hugging Face Hub.
class LRMGeneratorWrapper:
    """Load the image-to-3D model from the Hugging Face Hub and run it in inference mode.

    Args:
        model_id: Hub repository id (or local path) to load the config and weights from.
        device: Device to place the model on. When None, CUDA is used if available,
            otherwise the CPU.
    """

    def __init__(self, model_id="jadechoghari/custom-llrm", device=None):
        # trust_remote_code: the model implementation is shipped in the Hub repo itself.
        self.config = AutoConfig.from_pretrained(model_id, trust_remote_code=True)
        # Hand the already-loaded config to the model so it is not resolved a second time.
        self.model = AutoModel.from_pretrained(
            model_id, config=self.config, trust_remote_code=True
        )
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.device = torch.device(device)
        self.model.to(self.device)
        # Inference only: put the model in evaluation mode.
        self.model.eval()

    def forward(self, image, camera):
        """Run the wrapped model on ``image`` with ``camera`` and return its raw output."""
        return self.model(image, camera)
# Module-level singleton: the model is loaded once, when this file is imported,
# and shared by every call to generate_mesh.
model_wrapper = LRMGeneratorWrapper()
@lru_cache(maxsize=None)
def _get_rembg_session(model_name="isnet-general-use"):
    """Create the rembg session for ``model_name`` once and reuse it on later calls."""
    return new_session(model_name)


def preprocess_image(image, source_size, border_ratio=0.20):
    """Remove the background, recenter the subject and resize ``image`` for the model.

    Args:
        image: Input image; anything ``np.array`` accepts (e.g. a PIL image).
        source_size: Side length, in pixels, of the square output.
        border_ratio: Empty-border fraction passed to ``kiui.op.recenter``.

    Returns:
        A float tensor of shape (1, 3, source_size, source_size) with values in [0, 1].
        Transparent regions are composited onto white.
    """
    # Building a session loads the segmentation model, so do it only once per process.
    rembg_remove = partial(remove, session=_get_rembg_session())
    image = rembg_remove(np.array(image))
    # The cut-out's alpha channel is the foreground mask. Reuse it instead of running
    # the segmentation network a second time.
    mask = image[..., 3] > 0
    image = recenter(image, mask, border_ratio=border_ratio)
    image = torch.tensor(image).permute(2, 0, 1).unsqueeze(0) / 255.0
    if image.shape[1] == 4:
        # Alpha-composite onto white: rgb * a + 1 * (1 - a).
        rgb, alpha = image[:, :3, ...], image[:, 3:, ...]
        image = rgb * alpha + (1 - alpha)
    image = torch.nn.functional.interpolate(
        image, size=(source_size, source_size), mode='bicubic', align_corners=True
    )
    # Bicubic interpolation can overshoot the input range.
    return torch.clamp(image, 0, 1)
# Reference: https://github.com/jadechoghari/vfusion3d/blob/main/lrm/inferrer.py
def generate_mesh(image, source_size=512, render_size=384, mesh_size=512, export_mesh=True,
                  mesh_path="awesome_mesh.obj"):
    """Reconstruct a vertex-coloured triangle mesh from one image and export it as OBJ.

    Args:
        image: Input image, forwarded to ``preprocess_image``.
        source_size: Side length the input is resized to before inference.
        render_size: Currently unused; kept so existing callers keep working.
        mesh_size: Resolution of the density grid sampled for marching cubes.
        export_mesh: When False, the model is run but nothing is exported and None is returned.
        mesh_path: Where the OBJ file is written. The same path is reused on every call,
            so concurrent calls sharing it would overwrite each other's output.

    Returns:
        ``mesh_path`` when ``export_mesh`` is True, otherwise None.

    Raises:
        ValueError: If marching cubes finds no surface in the density grid.
    """
    device = model_wrapper.device
    image = preprocess_image(image, source_size).to(device)
    # NOTE(review): 16 raw camera values, presumably 12 flattened 3x4 extrinsics + 4 intrinsics.
    # The repeated [0, 0, 1, 1] pattern looks like a placeholder; verify the shape and values
    # against the reference inferrer.
    source_camera = torch.tensor(
        [[0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1]], dtype=torch.float32, device=device
    )
    with torch.no_grad():
        planes = model_wrapper.forward(image, source_camera)
        if not export_mesh:
            return None
        synthesizer = model_wrapper.model.synthesizer
        grid_out = synthesizer.forward_grid(planes=planes, grid_size=mesh_size)
        sigma = grid_out['sigma'].float().squeeze(0).squeeze(-1).cpu().numpy()
        # Extract the iso-surface of the density grid at level 1.0.
        vtx, faces = mcubes.marching_cubes(sigma, 1.0)
        if len(vtx) == 0:
            raise ValueError("Marching cubes found no surface; the reconstruction is empty.")
        # Map grid indices [0, mesh_size - 1] to coordinates in [-1, 1].
        vtx = vtx / (mesh_size - 1) * 2 - 1
        vtx_tensor = torch.tensor(vtx, dtype=torch.float32, device=device).unsqueeze(0)
        vtx_colors = synthesizer.forward_points(planes, vtx_tensor)['rgb'].float().squeeze(0).cpu().numpy()
    # Clip before the uint8 cast: out-of-range floats would otherwise cast to undefined values.
    vtx_colors = np.clip(vtx_colors * 255, 0, 255).astype(np.uint8)
    mesh = trimesh.Trimesh(vertices=vtx, faces=faces, vertex_colors=vtx_colors)
    mesh.export(mesh_path, 'obj')
    return mesh_path
# TODO: return an interactive 3D model (e.g. via gr.Model3D) instead of a downloadable .obj file.
def gradio_interface(image):
    """Gradio callback: build a mesh from the uploaded image and return the OBJ file path."""
    if image is None:
        # gr.Image passes None when the user submits without uploading an image.
        raise gr.Error("Please upload an image first.")
    mesh_file = generate_mesh(image)
    print("Generated Mesh File Path:", mesh_file)
    return mesh_file
# Single-image -> .obj web UI.
demo = gr.Interface(
    fn=gradio_interface,
    inputs=gr.Image(type="pil", label="Input Image"),
    outputs=gr.File(label="Awesome 3D Mesh (.obj)"),
    title="3D Mesh Generator by FacebookAI",
    description="Upload an image and generate a 3D mesh (.obj) file using VFusion3D by FacebookAI"
)

# Start the server only when run as a script (e.g. `python app.py`), not when imported.
if __name__ == "__main__":
    demo.launch()