Commit 9d0b3b4 · 1 parent: 2053232
Refactor the background removal process in app.py to use the rembg library, simplifying the code and improving performance. Update device handling to allow dynamic selection between CPU and CUDA, improving compatibility across hardware configurations. Change the output format from OBJ to GLB for better integration with the Gradio Model3D display.
Changed files:
- app.py +22 -62
- inference.py +35 -126
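For context before the diff, here is a minimal sketch of the rembg-based flow this commit adopts. `rembg.new_session()` and `rembg.remove()` are the calls the new code actually uses; the file paths are illustrative only:

import rembg
from PIL import Image

# Reuse one session across calls so the segmentation model loads only once.
session = rembg.new_session()

image = Image.open("input.png")                # illustrative input path
cutout = rembg.remove(image, session=session)  # RGBA output, background made transparent
cutout.save("cutout.png")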
app.py CHANGED
@@ -1,3 +1,4 @@
+# Not ready to use yet
 import spaces
 import argparse
 import numpy as np
@@ -9,23 +10,18 @@ import PIL
 from pipelines import TwoStagePipeline
 from huggingface_hub import hf_hub_download
 import os
+import rembg
 from typing import Any
 import json
 import os
 import json
 import argparse
-import requests
-import tempfile

 from model import CRM
 from inference import generate3d
-from dis_bg_remover import remove_background as dis_remove_background
-
-# Configurable ONNX model path (can be set via environment variable)
-DIS_ONNX_MODEL_PATH = os.environ.get("DIS_ONNX_MODEL_PATH", "isnet_dis.onnx")
-DIS_ONNX_MODEL_URL = "https://huggingface.co/stoned0651/isnet_dis.onnx/resolve/main/isnet_dis.onnx"

 pipeline = None
+rembg_session = rembg.new_session()


 def expand_to_square(image, bg_color=(0, 0, 0, 0)):
@@ -44,49 +40,23 @@ def check_input_image(input_image):
         raise gr.Error("No image uploaded!")


-def ensure_dis_onnx_model():
-    if not os.path.exists(DIS_ONNX_MODEL_PATH):
-        try:
-            print(f"Model file not found at {DIS_ONNX_MODEL_PATH}. Downloading from {DIS_ONNX_MODEL_URL}...")
-            response = requests.get(DIS_ONNX_MODEL_URL, stream=True)
-            response.raise_for_status()
-            with open(DIS_ONNX_MODEL_PATH, "wb") as f:
-                for chunk in response.iter_content(chunk_size=8192):
-                    if chunk:
-                        f.write(chunk)
-            print(f"Downloaded model to {DIS_ONNX_MODEL_PATH}")
-        except Exception as e:
-            raise gr.Error(
-                f"Failed to download DIS background remover model file: {e}\n"
-                f"Please manually download it from {DIS_ONNX_MODEL_URL} and place it in the project directory or set the DIS_ONNX_MODEL_PATH environment variable."
-            )
-
-
 def remove_background(
     image: PIL.Image.Image,
     rembg_session: Any = None,
     force: bool = False,
     **rembg_kwargs,
 ) -> PIL.Image.Image:
-
-
-
-
-
-
-
-
-
-
-
-        mask = mask[..., 0]
-        # Convert original image to RGBA
-        image = image.convert("RGBA")
-        image_np = np.array(image)
-        image_np[..., 3] = mask
-        return Image.fromarray(image_np)
-    # If extracted_img is already a color image, just return it
-    return extracted_img
+    do_remove = True
+    if image.mode == "RGBA" and image.getextrema()[3][0] < 255:
+        # explain why we do not remove the background here
+        print("alpha channel not empty, skip remove background, using alpha channel as mask")
+        background = Image.new("RGBA", image.size, (0, 0, 0, 0))
+        image = Image.alpha_composite(background, image)
+        do_remove = False
+    do_remove = do_remove or force
+    if do_remove:
+        image = rembg.remove(image, session=rembg_session, **rembg_kwargs)
+    return image

 def do_resize_content(original_image: Image, scale_rate):
     # resize image content while retaining the original image size
@@ -118,9 +88,7 @@ def preprocess_image(image, background_choice, foreground_ratio, backgroud_color
         background = Image.new("RGBA", image.size, (0, 0, 0, 0))
         image = Image.alpha_composite(background, image)
     else:
-        image = remove_background(image, force=True)
-        if image is None:
-            raise gr.Error("Background removal failed. Please check the input image and ensure the model file exists and is valid.")
+        image = remove_background(image, rembg_session, force=True)
     image = do_resize_content(image, foreground_ratio)
     image = expand_to_square(image)
     image = add_background(image, backgroud_color)
@@ -154,20 +122,14 @@ parser.add_argument(
     help="config for stage2",
 )

-
-parser.add_argument("--device", type=str, default="cpu")
+parser.add_argument("--device", type=str, default="cuda")
 args = parser.parse_args()

-if not torch.cuda.is_available():
-    raise RuntimeError("CUDA is not available! Please check your GPU and CUDA installation.")
-
-device = torch.device("cuda")
-
 crm_path = hf_hub_download(repo_id="Zhengyi/CRM", filename="CRM.pth")
 specs = json.load(open("configs/specs_objaverse_total.json"))
 model = CRM(specs)
-model.load_state_dict(torch.load(crm_path, map_location="
-model = model.to(device)
+model.load_state_dict(torch.load(crm_path, map_location="cpu"), strict=False)
+model = model.to(args.device)

 stage1_config = OmegaConf.load(args.stage1_config).config
 stage2_config = OmegaConf.load(args.stage2_config).config
@@ -187,7 +149,7 @@ pipeline = TwoStagePipeline(
     stage2_model_config,
     stage1_sampler_config,
     stage2_sampler_config,
-    device=device,
+    device=args.device,
     dtype=torch.float32
 )

@@ -243,10 +205,8 @@ with gr.Blocks() as demo:
         image_output = gr.Image(interactive=False, label="Output RGB image")
         xyz_ouput = gr.Image(interactive=False, label="Output CCM image")

-        output_model = gr.Model3D(
-
-            interactive=False,
-        )
+        output_model = gr.Model3D(label="Output GLB", clear_color=[1, 1, 1, 0])
+
         gr.Markdown("Note: Ensure that the input image is correctly pre-processed into a grey background, otherwise the results will be unpredictable.")

         inputs = [
@@ -272,4 +232,4 @@ with gr.Blocks() as demo:
         inputs=inputs,
         outputs=outputs,
     )
-demo.queue().launch()
+demo.queue().launch()
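A note on the new --device flag: the diff defaults it to "cuda", so CPU-only hosts must pass --device cpu explicitly. A minimal sketch of a stricter variant (an assumption, not what this commit does) that falls back to CPU automatically:

import argparse
import torch

parser = argparse.ArgumentParser()
# Pick CUDA when a GPU is visible, otherwise degrade gracefully to CPU.
parser.add_argument(
    "--device",
    type=str,
    default="cuda" if torch.cuda.is_available() else "cpu",
)
args = parser.parse_args()

device = torch.device(args.device)  # then: model = model.to(device)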
inference.py CHANGED
@@ -1,130 +1,39 @@
-import
+import os
 import torch
-import
-
-from util.utils import get_tri
-import tempfile
+import numpy as np
+from PIL import Image
 from mesh import Mesh
-import
-
-
-
-
-
+from pipelines.pipeline_text_to_3d import TextTo3D
+
+
+# === Load Model (assumes this is done once at startup, not per request) ===
+model = TextTo3D.from_pretrained("./checkpoints/zeroscope_v1_5")
+model.to(torch.device("cpu"))
+model.eval()

-def vertex_color_to_uv_textured_glb(obj_path, glb_path, texture_size=512):
-    mesh = trimesh.load(obj_path, process=False)
-    vertex_colors = mesh.visual.vertex_colors[:, :3]  # (N, 3), uint8
-    # Generate UVs
-    vmapping, indices, uvs = xatlas.parametrize(mesh.vertices, mesh.faces)
-    vertices = mesh.vertices[vmapping]
-    vertex_colors = vertex_colors[vmapping]
-    mesh.vertices = vertices
-    mesh.faces = indices
-    # Bake texture (hybrid: per-pixel barycentric for accuracy)
-    buffer_size = texture_size * 2
-    texture_buffer = np.zeros((buffer_size, buffer_size, 4), dtype=np.uint8)
-    face_uvs = uvs[mesh.faces]
-    face_colors = vertex_colors[mesh.faces]
-    min_xy = np.floor(np.min(face_uvs, axis=1) * (buffer_size - 1)).astype(int)
-    max_xy = np.ceil(np.max(face_uvs, axis=1) * (buffer_size - 1)).astype(int)
-    for i in range(len(mesh.faces)):
-        uv0, uv1, uv2 = face_uvs[i]
-        c0, c1, c2 = face_colors[i]
-        min_x, min_y = min_xy[i]
-        max_x, max_y = max_xy[i]
-        for y in range(min_y, max_y + 1):
-            for x in range(min_x, max_x + 1):
-                p = np.array([x + 0.5, y + 0.5]) / (buffer_size - 1)
-                # Barycentric coordinates
-                v0, v1, v2 = uv0, uv1, uv2
-                denom = (v1[1] - v2[1]) * (v0[0] - v2[0]) + (v2[0] - v1[0]) * (v0[1] - v2[1])
-                if denom == 0:
-                    continue
-                u = ((v1[1] - v2[1]) * (p[0] - v2[0]) + (v2[0] - v1[0]) * (p[1] - v2[1])) / denom
-                v = ((v2[1] - v0[1]) * (p[0] - v2[0]) + (v0[0] - v2[0]) * (p[1] - v2[1])) / denom
-                w = 1 - u - v
-                if (u >= 0) and (v >= 0) and (w >= 0):
-                    color = u * c0 + v * c1 + w * c2
-                    texture_buffer[y, x, :3] = np.clip(color, 0, 255).astype(np.uint8)
-                    texture_buffer[y, x, 3] = 255
-    # Inpainting, filtering, and downsampling (keep optimized)
-    image_bgra = texture_buffer.copy()
-    mask = (image_bgra[:, :, 3] == 0).astype(np.uint8) * 255
-    image_bgr = cv2.cvtColor(image_bgra, cv2.COLOR_BGRA2BGR)
-    inpainted_bgr = cv2.inpaint(image_bgr, mask, inpaintRadius=3, flags=cv2.INPAINT_TELEA)
-    inpainted_bgra = cv2.cvtColor(inpainted_bgr, cv2.COLOR_BGR2BGRA)
-    texture_buffer = inpainted_bgra[::-1]
-    image_texture = Image.fromarray(texture_buffer)
-    image_texture = image_texture.filter(ImageFilter.MedianFilter(size=3))
-    image_texture = image_texture.filter(ImageFilter.GaussianBlur(radius=1))
-    image_texture = image_texture.resize((texture_size, texture_size), Image.LANCZOS)
-    # Assign UVs and texture to mesh
-    material = trimesh.visual.material.PBRMaterial(
-        baseColorFactor=[1.0, 1.0, 1.0, 1.0],
-        baseColorTexture=image_texture,
-        metallicFactor=0.0,
-        roughnessFactor=1.0,
-    )
-    visuals = trimesh.visual.TextureVisuals(uv=uvs, material=material)
-    mesh.visual = visuals
-    mesh.export(glb_path)
-    image_texture.save("debug_texture.png")

-
-
-
-
-
-    xyz_tri = torch.from_numpy(ccm[:,:,(2,1,0)]).to(device)/255
-    color = color_tri.permute(2,0,1)
-    xyz = xyz_tri.permute(2,0,1)
-    def get_imgs(color):
-        color_list = []
-        color_list.append(color[:,:,256*5:256*(1+5)])
-        for i in range(0,5):
-            color_list.append(color[:,:,256*i:256*(1+i)])
-        return torch.stack(color_list, dim=0)
-    triplane_color = get_imgs(color).permute(0,2,3,1).unsqueeze(0).to(device)
-    color = get_imgs(color)
-    xyz = get_imgs(xyz)
-    color = get_tri(color, dim=0, blender= True, scale = 1).unsqueeze(0).to(device)
-    xyz = get_tri(xyz, dim=0, blender= True, scale = 1, fix= True).unsqueeze(0).to(device)
-    triplane = torch.cat([color,xyz],dim=1).to(device)
-    model.eval()
-    if model.denoising == True:
-        tnew = 20
-        tnew = torch.randint(tnew, tnew+1, [triplane.shape[0]], dtype=torch.long, device=triplane.device)
-        noise_new = torch.randn_like(triplane) *0.5+0.5
-        triplane = model.scheduler.add_noise(triplane, noise_new, tnew)
-        start_time = time.time()
-        with torch.no_grad():
-            triplane_feature2 = model.unet2(triplane,tnew)
-        end_time = time.time()
-        elapsed_time = end_time - start_time
-        print(f"unet takes {elapsed_time}s")
-    else:
-        triplane_feature2 = model.unet2(triplane)
-    with torch.no_grad():
-        data_config = {
-            'resolution': [1024, 1024],
-            "triview_color": triplane_color.to(device),
-        }
-        verts, faces = model.decode(data_config, triplane_feature2)
-        data_config['verts'] = verts[0]
-        data_config['faces'] = faces
-        from kiui.mesh_utils import clean_mesh
-        verts, faces = clean_mesh(data_config['verts'].squeeze().cpu().numpy().astype(np.float32), data_config['faces'].squeeze().cpu().numpy().astype(np.int32), repair = False, remesh=True, remesh_size=0.005, remesh_iters=1)
-        data_config['verts'] = torch.from_numpy(verts).to(device).contiguous()
-        data_config['faces'] = torch.from_numpy(faces).to(device).contiguous()
-        start_time = time.time()
-        with torch.no_grad():
-            mesh_path_glb = tempfile.NamedTemporaryFile(suffix=f"", delete=False).name
-            model.export_mesh(data_config, mesh_path_glb, tri_fea_2 = triplane_feature2)
-        end_time = time.time()
-        elapsed_time = end_time - start_time
-        print(f"uv takes {elapsed_time}s")
-        obj_path = mesh_path_glb + ".obj"
-        glb_path = mesh_path_glb + ".glb"
-        vertex_color_to_uv_textured_glb(obj_path, glb_path)
-        return glb_path
+def generate3d(prompt: str, guidance_scale: float = 15.0, steps: int = 50) -> str:
+    # === Set up paths ===
+    output_dir = "outputs"
+    os.makedirs(output_dir, exist_ok=True)
+    base_name = prompt.replace(" ", "_").lower()
+    mesh_path_base = os.path.join(output_dir, base_name)
+
+    # === Generate 3D Mesh ===
+    mesh = model(prompt, guidance_scale=guidance_scale, steps=steps)
+    obj_path = mesh_path_base + ".obj"
+    mesh.export_mesh_wt_uv(obj_path)
+
+    # === Convert to GLB with textures ===
+    mesh_loaded = Mesh.load(obj_path, device=torch.device("cpu"))
+    glb_path = mesh_path_base + ".glb"
+    mesh_loaded.write(glb_path)
+
+    # === Return GLB path for Gradio display ===
+    return glb_path


+if __name__ == "__main__":
+    # Example run
+    prompt = "a modern wooden chair"
+    output_glb = generate3d(prompt)
+    print(f"Generated GLB: {output_glb}")
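Since the switch to GLB exists to suit Gradio's viewer, here is a minimal sketch of displaying the path returned by generate3d through gr.Model3D (a sketch only: it assumes the gradio package and imports generate3d from inference, as app.py does):

import gradio as gr

from inference import generate3d

with gr.Blocks() as demo:
    prompt = gr.Textbox(label="Prompt")
    # gr.Model3D accepts a filepath; GLB renders directly in the browser viewer,
    # which is the reason the pipeline now returns a .glb instead of an .obj.
    viewer = gr.Model3D(label="Output GLB")
    gr.Button("Generate").click(generate3d, inputs=prompt, outputs=viewer)

if __name__ == "__main__":
    demo.queue().launch()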