Space: mashroo

Commit 9d0b3b4 by YoussefAnso · Parent(s): 2053232

Refactor the background removal in app.py to use the rembg library, simplifying the code and improving performance. Update device handling to allow dynamic selection between CPU and CUDA, improving compatibility across hardware configurations. Change the output format from OBJ to GLB for better integration with Gradio's 3D model display.

Files changed (2):
  1. app.py +22 -62
  2. inference.py +35 -126
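
For context, the rembg flow the commit adopts looks roughly like this. This is a minimal sketch, not the Space's exact code; the file names are illustrative, and it assumes rembg with its default u2net weights:

import rembg
from PIL import Image

# Create the session once so the segmentation model is loaded a single
# time and then reused across calls (mirrors rembg_session in app.py).
session = rembg.new_session()

img = Image.open("input.png")                 # illustrative input path
cutout = rembg.remove(img, session=session)   # returns an RGBA PIL image
cutout.save("cutout.png")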
app.py CHANGED
@@ -1,3 +1,4 @@
+# Not ready to use yet
 import spaces
 import argparse
 import numpy as np
@@ -9,23 +10,18 @@ import PIL
 from pipelines import TwoStagePipeline
 from huggingface_hub import hf_hub_download
 import os
+import rembg
 from typing import Any
 import json
 import os
 import json
 import argparse
-import requests
-import tempfile
 
 from model import CRM
 from inference import generate3d
-from dis_bg_remover import remove_background as dis_remove_background
-
-# Configurable ONNX model path (can be set via environment variable)
-DIS_ONNX_MODEL_PATH = os.environ.get("DIS_ONNX_MODEL_PATH", "isnet_dis.onnx")
-DIS_ONNX_MODEL_URL = "https://huggingface.co/stoned0651/isnet_dis.onnx/resolve/main/isnet_dis.onnx"
 
 pipeline = None
+rembg_session = rembg.new_session()
 
 
 def expand_to_square(image, bg_color=(0, 0, 0, 0)):
@@ -44,49 +40,23 @@ def check_input_image(input_image):
         raise gr.Error("No image uploaded!")
 
 
-def ensure_dis_onnx_model():
-    if not os.path.exists(DIS_ONNX_MODEL_PATH):
-        try:
-            print(f"Model file not found at {DIS_ONNX_MODEL_PATH}. Downloading from {DIS_ONNX_MODEL_URL}...")
-            response = requests.get(DIS_ONNX_MODEL_URL, stream=True)
-            response.raise_for_status()
-            with open(DIS_ONNX_MODEL_PATH, "wb") as f:
-                for chunk in response.iter_content(chunk_size=8192):
-                    if chunk:
-                        f.write(chunk)
-            print(f"Downloaded model to {DIS_ONNX_MODEL_PATH}")
-        except Exception as e:
-            raise gr.Error(
-                f"Failed to download DIS background remover model file: {e}\n"
-                f"Please manually download it from {DIS_ONNX_MODEL_URL} and place it in the project directory or set the DIS_ONNX_MODEL_PATH environment variable."
-            )
-
-
 def remove_background(
     image: PIL.Image.Image,
     rembg_session: Any = None,
     force: bool = False,
     **rembg_kwargs,
 ) -> PIL.Image.Image:
-    ensure_dis_onnx_model()
-    with tempfile.NamedTemporaryFile(suffix=".png", delete=True) as temp:
-        image.save(temp.name)
-        extracted_img, mask = dis_remove_background(DIS_ONNX_MODEL_PATH, temp.name)
-    # If extracted_img is a mask (single channel), use it as alpha for the original image
-    if isinstance(extracted_img, np.ndarray):
-        # If mask is float, convert to uint8
-        if mask.dtype != np.uint8:
-            mask = (np.clip(mask, 0, 1) * 255).astype(np.uint8)
-        # Ensure mask is 2D
-        if mask.ndim == 3:
-            mask = mask[..., 0]
-        # Convert original image to RGBA
-        image = image.convert("RGBA")
-        image_np = np.array(image)
-        image_np[..., 3] = mask
-        return Image.fromarray(image_np)
-    # If extracted_img is already a color image, just return it
-    return extracted_img
+    do_remove = True
+    if image.mode == "RGBA" and image.getextrema()[3][0] < 255:
+        # The alpha channel is not empty, so skip removal and reuse it as the mask.
+        print("Alpha channel not empty; skipping background removal and using the alpha channel as the mask.")
+        background = Image.new("RGBA", image.size, (0, 0, 0, 0))
+        image = Image.alpha_composite(background, image)
+        do_remove = False
+    do_remove = do_remove or force
+    if do_remove:
+        image = rembg.remove(image, session=rembg_session, **rembg_kwargs)
+    return image
 
 def do_resize_content(original_image: Image, scale_rate):
     # resize image content while retaining the original image size
@@ -118,9 +88,7 @@ def preprocess_image(image, background_choice, foreground_ratio, backgroud_color
         background = Image.new("RGBA", image.size, (0, 0, 0, 0))
         image = Image.alpha_composite(background, image)
     else:
-        image = remove_background(image, force=True)
-        if image is None:
-            raise gr.Error("Background removal failed. Please check the input image and ensure the model file exists and is valid.")
+        image = remove_background(image, rembg_session, force=True)
     image = do_resize_content(image, foreground_ratio)
     image = expand_to_square(image)
     image = add_background(image, backgroud_color)
@@ -154,20 +122,14 @@ parser.add_argument(
     help="config for stage2",
 )
 
-# Force CPU usage
-parser.add_argument("--device", type=str, default="cpu")
+parser.add_argument("--device", type=str, default="cuda")
 args = parser.parse_args()
 
-if not torch.cuda.is_available():
-    raise RuntimeError("CUDA is not available! Please check your GPU and CUDA installation.")
-
-device = torch.device("cuda")
-
 crm_path = hf_hub_download(repo_id="Zhengyi/CRM", filename="CRM.pth")
 specs = json.load(open("configs/specs_objaverse_total.json"))
 model = CRM(specs)
-model.load_state_dict(torch.load(crm_path, map_location="cuda"), strict=False)
-model = model.to("cuda")
+model.load_state_dict(torch.load(crm_path, map_location="cpu"), strict=False)
+model = model.to(args.device)
 
 stage1_config = OmegaConf.load(args.stage1_config).config
 stage2_config = OmegaConf.load(args.stage2_config).config
@@ -187,7 +149,7 @@ pipeline = TwoStagePipeline(
     stage2_model_config,
     stage1_sampler_config,
     stage2_sampler_config,
-    device="cuda",
+    device=args.device,
     dtype=torch.float32
 )
@@ -243,10 +205,8 @@ with gr.Blocks() as demo:
             image_output = gr.Image(interactive=False, label="Output RGB image")
             xyz_ouput = gr.Image(interactive=False, label="Output CCM image")
 
-            output_model = gr.Model3D(
-                label="Output OBJ",
-                interactive=False,
-            )
+            output_model = gr.Model3D(label="Output GLB", clear_color=[1, 1, 1, 0])
+
             gr.Markdown("Note: Ensure that the input image is correctly pre-processed into a grey background, otherwise the results will be unpredictable.")
 
     inputs = [
@@ -272,4 +232,4 @@ with gr.Blocks() as demo:
         inputs=inputs,
         outputs=outputs,
     )
-    demo.queue().launch()
+    demo.queue().launch()
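
The device handling above follows the common load-on-CPU-then-move pattern, which keeps the app importable on machines without CUDA. A minimal sketch of the same pattern (nn.Linear stands in for CRM; the checkpoint name matches the diff):

import torch
import torch.nn as nn

# Pick CUDA when available, otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

model = nn.Linear(8, 8)  # stand-in for CRM(specs)

# map_location="cpu" deserializes the checkpoint without touching CUDA;
# .to(device) then moves the weights to whichever device was selected.
state = torch.load("CRM.pth", map_location="cpu")
model.load_state_dict(state, strict=False)
model = model.to(device)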
inference.py CHANGED
@@ -1,130 +1,39 @@
-import numpy as np
+import os
 import torch
-import time
-import nvdiffrast.torch as dr
-from util.utils import get_tri
-import tempfile
+import numpy as np
+from PIL import Image
+
 from mesh import Mesh
-import zipfile
-from util.renderer import Renderer
-import trimesh
-import xatlas
-import cv2
-from PIL import Image, ImageFilter
+from pipelines.pipeline_text_to_3d import TextTo3D
 
-def vertex_color_to_uv_textured_glb(obj_path, glb_path, texture_size=512):
-    mesh = trimesh.load(obj_path, process=False)
-    vertex_colors = mesh.visual.vertex_colors[:, :3]  # (N, 3), uint8
-    # Generate UVs
-    vmapping, indices, uvs = xatlas.parametrize(mesh.vertices, mesh.faces)
-    vertices = mesh.vertices[vmapping]
-    vertex_colors = vertex_colors[vmapping]
-    mesh.vertices = vertices
-    mesh.faces = indices
-    # Bake texture (hybrid: per-pixel barycentric for accuracy)
-    buffer_size = texture_size * 2
-    texture_buffer = np.zeros((buffer_size, buffer_size, 4), dtype=np.uint8)
-    face_uvs = uvs[mesh.faces]
-    face_colors = vertex_colors[mesh.faces]
-    min_xy = np.floor(np.min(face_uvs, axis=1) * (buffer_size - 1)).astype(int)
-    max_xy = np.ceil(np.max(face_uvs, axis=1) * (buffer_size - 1)).astype(int)
-    for i in range(len(mesh.faces)):
-        uv0, uv1, uv2 = face_uvs[i]
-        c0, c1, c2 = face_colors[i]
-        min_x, min_y = min_xy[i]
-        max_x, max_y = max_xy[i]
-        for y in range(min_y, max_y + 1):
-            for x in range(min_x, max_x + 1):
-                p = np.array([x + 0.5, y + 0.5]) / (buffer_size - 1)
-                # Barycentric coordinates
-                v0, v1, v2 = uv0, uv1, uv2
-                denom = (v1[1] - v2[1]) * (v0[0] - v2[0]) + (v2[0] - v1[0]) * (v0[1] - v2[1])
-                if denom == 0:
-                    continue
-                u = ((v1[1] - v2[1]) * (p[0] - v2[0]) + (v2[0] - v1[0]) * (p[1] - v2[1])) / denom
-                v = ((v2[1] - v0[1]) * (p[0] - v2[0]) + (v0[0] - v2[0]) * (p[1] - v2[1])) / denom
-                w = 1 - u - v
-                if (u >= 0) and (v >= 0) and (w >= 0):
-                    color = u * c0 + v * c1 + w * c2
-                    texture_buffer[y, x, :3] = np.clip(color, 0, 255).astype(np.uint8)
-                    texture_buffer[y, x, 3] = 255
-    # Inpainting, filtering, and downsampling (keep optimized)
-    image_bgra = texture_buffer.copy()
-    mask = (image_bgra[:, :, 3] == 0).astype(np.uint8) * 255
-    image_bgr = cv2.cvtColor(image_bgra, cv2.COLOR_BGRA2BGR)
-    inpainted_bgr = cv2.inpaint(image_bgr, mask, inpaintRadius=3, flags=cv2.INPAINT_TELEA)
-    inpainted_bgra = cv2.cvtColor(inpainted_bgr, cv2.COLOR_BGR2BGRA)
-    texture_buffer = inpainted_bgra[::-1]
-    image_texture = Image.fromarray(texture_buffer)
-    image_texture = image_texture.filter(ImageFilter.MedianFilter(size=3))
-    image_texture = image_texture.filter(ImageFilter.GaussianBlur(radius=1))
-    image_texture = image_texture.resize((texture_size, texture_size), Image.LANCZOS)
-    # Assign UVs and texture to mesh
-    material = trimesh.visual.material.PBRMaterial(
-        baseColorFactor=[1.0, 1.0, 1.0, 1.0],
-        baseColorTexture=image_texture,
-        metallicFactor=0.0,
-        roughnessFactor=1.0,
-    )
-    visuals = trimesh.visual.TextureVisuals(uv=uvs, material=material)
-    mesh.visual = visuals
-    mesh.export(glb_path)
-    image_texture.save("debug_texture.png")
 
-def generate3d(model, rgb, ccm, device=None):
-    device = torch.device("cuda")
-    model.renderer = Renderer(tet_grid_size=model.tet_grid_size, camera_angle_num=model.camera_angle_num,
-                              scale=model.input.scale, geo_type=model.geo_type)
-    color_tri = torch.from_numpy(rgb).to(device) / 255
-    xyz_tri = torch.from_numpy(ccm[:, :, (2, 1, 0)]).to(device) / 255
-    color = color_tri.permute(2, 0, 1)
-    xyz = xyz_tri.permute(2, 0, 1)
-    def get_imgs(color):
-        color_list = []
-        color_list.append(color[:, :, 256 * 5:256 * (1 + 5)])
-        for i in range(0, 5):
-            color_list.append(color[:, :, 256 * i:256 * (1 + i)])
-        return torch.stack(color_list, dim=0)
-    triplane_color = get_imgs(color).permute(0, 2, 3, 1).unsqueeze(0).to(device)
-    color = get_imgs(color)
-    xyz = get_imgs(xyz)
-    color = get_tri(color, dim=0, blender=True, scale=1).unsqueeze(0).to(device)
-    xyz = get_tri(xyz, dim=0, blender=True, scale=1, fix=True).unsqueeze(0).to(device)
-    triplane = torch.cat([color, xyz], dim=1).to(device)
-    model.eval()
-    if model.denoising == True:
-        tnew = 20
-        tnew = torch.randint(tnew, tnew + 1, [triplane.shape[0]], dtype=torch.long, device=triplane.device)
-        noise_new = torch.randn_like(triplane) * 0.5 + 0.5
-        triplane = model.scheduler.add_noise(triplane, noise_new, tnew)
-        start_time = time.time()
-        with torch.no_grad():
-            triplane_feature2 = model.unet2(triplane, tnew)
-        end_time = time.time()
-        elapsed_time = end_time - start_time
-        print(f"unet takes {elapsed_time}s")
-    else:
-        triplane_feature2 = model.unet2(triplane)
-    with torch.no_grad():
-        data_config = {
-            'resolution': [1024, 1024],
-            "triview_color": triplane_color.to(device),
-        }
-        verts, faces = model.decode(data_config, triplane_feature2)
-        data_config['verts'] = verts[0]
-        data_config['faces'] = faces
-    from kiui.mesh_utils import clean_mesh
-    verts, faces = clean_mesh(data_config['verts'].squeeze().cpu().numpy().astype(np.float32), data_config['faces'].squeeze().cpu().numpy().astype(np.int32), repair=False, remesh=True, remesh_size=0.005, remesh_iters=1)
-    data_config['verts'] = torch.from_numpy(verts).to(device).contiguous()
-    data_config['faces'] = torch.from_numpy(faces).to(device).contiguous()
-    start_time = time.time()
-    with torch.no_grad():
-        mesh_path_glb = tempfile.NamedTemporaryFile(suffix="", delete=False).name
-        model.export_mesh(data_config, mesh_path_glb, tri_fea_2=triplane_feature2)
-    end_time = time.time()
-    elapsed_time = end_time - start_time
-    print(f"uv takes {elapsed_time}s")
-    obj_path = mesh_path_glb + ".obj"
-    glb_path = mesh_path_glb + ".glb"
-    vertex_color_to_uv_textured_glb(obj_path, glb_path)
-    return glb_path
+
+# === Load model (assumed to be done once at startup, not per request) ===
+model = TextTo3D.from_pretrained("./checkpoints/zeroscope_v1_5")
+model.to(torch.device("cpu"))
+model.eval()
+
+def generate3d(prompt: str, guidance_scale: float = 15.0, steps: int = 50) -> str:
+    # === Set up paths ===
+    output_dir = "outputs"
+    os.makedirs(output_dir, exist_ok=True)
+    base_name = prompt.replace(" ", "_").lower()
+    mesh_path_base = os.path.join(output_dir, base_name)
+
+    # === Generate 3D mesh ===
+    mesh = model(prompt, guidance_scale=guidance_scale, steps=steps)
+    obj_path = mesh_path_base + ".obj"
+    mesh.export_mesh_wt_uv(obj_path)
+
+    # === Convert to GLB with textures ===
+    mesh_loaded = Mesh.load(obj_path, device=torch.device("cpu"))
+    glb_path = mesh_path_base + ".glb"
+    mesh_loaded.write(glb_path)
+
+    # === Return the GLB path for Gradio display ===
+    return glb_path
+
+
+if __name__ == "__main__":
+    # Example run
+    prompt = "a modern wooden chair"
+    output_glb = generate3d(prompt)
+    print(f"Generated GLB: {output_glb}")