Refactor remove_background function in app.py to enhance mask handling. Add checks for mask dimensions and data type, ensuring proper conversion to an RGBA alpha channel. This improves output quality when masks are applied.
fa6cc29
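The core of the change is using the background remover's mask as the alpha channel of the original image. A minimal standalone sketch of that conversion (the function name is illustrative; it assumes a float mask in [0, 1] with the same height and width as the image):

import numpy as np
from PIL import Image

def apply_mask_as_alpha(image: Image.Image, mask: np.ndarray) -> Image.Image:
    # Convert a float mask in [0, 1] to uint8 in [0, 255].
    if mask.dtype != np.uint8:
        mask = (np.clip(mask, 0, 1) * 255).astype(np.uint8)
    # Collapse a (H, W, 1) mask to (H, W).
    if mask.ndim == 3:
        mask = mask[..., 0]
    # Write the mask into the alpha channel of an RGBA copy.
    rgba = np.array(image.convert("RGBA"))
    rgba[..., 3] = mask
    return Image.fromarray(rgba)
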
import spaces  # keep first: required for GPU support on Hugging Face Spaces
import argparse
import json
import os
import tempfile
from typing import Any

import numpy as np
import requests
import torch
import gradio as gr
from omegaconf import OmegaConf
import PIL
from PIL import Image
from huggingface_hub import hf_hub_download

from pipelines import TwoStagePipeline
from model import CRM
from inference import generate3d
from dis_bg_remover import remove_background as dis_remove_background

# Configurable ONNX model path (can be set via environment variable)
DIS_ONNX_MODEL_PATH = os.environ.get("DIS_ONNX_MODEL_PATH", "isnet_dis.onnx")
DIS_ONNX_MODEL_URL = "https://huggingface.co/stoned0651/isnet_dis.onnx/resolve/main/isnet_dis.onnx"
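# Example override (hypothetical path): DIS_ONNX_MODEL_PATH=/models/isnet_dis.onnx python app.py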
pipeline = None

def expand_to_square(image, bg_color=(0, 0, 0, 0)):
    # Pad the image to a 1:1 aspect ratio, centering the original content.
    width, height = image.size
    if width == height:
        return image
    new_size = (max(width, height), max(width, height))
    new_image = Image.new("RGBA", new_size, bg_color)
    paste_position = ((new_size[0] - width) // 2, (new_size[1] - height) // 2)
    new_image.paste(image, paste_position)
    return new_image

def check_input_image(input_image):
    if input_image is None:
        raise gr.Error("No image uploaded!")

def ensure_dis_onnx_model():
    # Download the DIS ONNX weights on first use if they are not present locally.
    if not os.path.exists(DIS_ONNX_MODEL_PATH):
        try:
            print(f"Model file not found at {DIS_ONNX_MODEL_PATH}. Downloading from {DIS_ONNX_MODEL_URL}...")
            response = requests.get(DIS_ONNX_MODEL_URL, stream=True, timeout=60)
            response.raise_for_status()
            with open(DIS_ONNX_MODEL_PATH, "wb") as f:
                for chunk in response.iter_content(chunk_size=8192):
                    if chunk:
                        f.write(chunk)
            print(f"Downloaded model to {DIS_ONNX_MODEL_PATH}")
        except Exception as e:
            raise gr.Error(
                f"Failed to download DIS background remover model file: {e}\n"
                f"Please download it manually from {DIS_ONNX_MODEL_URL} and place it in the project directory, or set the DIS_ONNX_MODEL_PATH environment variable."
            )

def remove_background(
    image: PIL.Image.Image,
    rembg_session: Any = None,
    force: bool = False,
    **rembg_kwargs,
) -> PIL.Image.Image:
    # `rembg_session`, `force`, and `rembg_kwargs` are kept for interface
    # compatibility with the previous rembg-based implementation; the DIS
    # remover does not use them.
    ensure_dis_onnx_model()
    with tempfile.NamedTemporaryFile(suffix=".png", delete=True) as temp:
        image.save(temp.name)
        # dis_remove_background returns (extracted_image, mask).
        extracted_img, mask = dis_remove_background(DIS_ONNX_MODEL_PATH, temp.name)
    # If the result is a NumPy array, apply the mask as the alpha channel
    # of the original image rather than using the extracted colors directly.
    if isinstance(extracted_img, np.ndarray):
        # Convert a float mask (assumed to be in [0, 1]) to uint8.
        if mask.dtype != np.uint8:
            mask = (np.clip(mask, 0, 1) * 255).astype(np.uint8)
        # Collapse a (H, W, 1) mask to (H, W).
        if mask.ndim == 3:
            mask = mask[..., 0]
        # Write the mask into the alpha channel of the original image.
        image = image.convert("RGBA")
        image_np = np.array(image)
        image_np[..., 3] = mask
        return Image.fromarray(image_np)
    # Otherwise the remover already produced a color image; return it as-is.
    return extracted_img

def do_resize_content(original_image: Image.Image, scale_rate):
    # Resize the image content while retaining the original canvas size.
    if scale_rate != 1:
        # Scale both dimensions by scale_rate, preserving the aspect ratio.
        new_size = tuple(int(dim * scale_rate) for dim in original_image.size)
        resized_image = original_image.resize(new_size)
        # Paste the resized content, centered, onto a transparent canvas of the original size.
        padded_image = Image.new("RGBA", original_image.size, (0, 0, 0, 0))
        paste_position = ((original_image.width - resized_image.width) // 2, (original_image.height - resized_image.height) // 2)
        padded_image.paste(resized_image, paste_position)
        return padded_image
    else:
        return original_image

def add_background(image, bg_color=(255, 255, 255)):
    # Composite an RGBA image over a solid background color, using its alpha channel as the mask.
    background = Image.new("RGBA", image.size, bg_color)
    return Image.alpha_composite(background, image)

def preprocess_image(image, background_choice, foreground_ratio, background_color):
    """
    Input is a PIL image in RGBA; returns an RGB image on a solid background.
    """
    print(background_choice)
    if background_choice == "Alpha as mask":
        background = Image.new("RGBA", image.size, (0, 0, 0, 0))
        image = Image.alpha_composite(background, image)
    else:
        image = remove_background(image, force=True)
    if image is None:
        raise gr.Error("Background removal failed. Please check the input image and ensure the model file exists and is valid.")
    image = do_resize_content(image, foreground_ratio)
    image = expand_to_square(image)
    image = add_background(image, background_color)
    return image.convert("RGB")

def gen_image(input_image, seed, scale, step):
    global pipeline, model, args
    pipeline.set_seed(seed)
    rt_dict = pipeline(input_image, scale=scale, step=step)
    stage1_images = rt_dict["stage1_images"]
    stage2_images = rt_dict["stage2_images"]
    # Tile the multi-view outputs of each stage horizontally for display.
    np_imgs = np.concatenate(stage1_images, 1)
    np_xyzs = np.concatenate(stage2_images, 1)
    glb_path = generate3d(model, np_imgs, np_xyzs, args.device)
    return Image.fromarray(np_imgs), Image.fromarray(np_xyzs), glb_path

parser = argparse.ArgumentParser()
parser.add_argument(
    "--stage1_config",
    type=str,
    default="configs/nf7_v3_SNR_rd_size_stroke.yaml",
    help="config for stage1",
)
parser.add_argument(
    "--stage2_config",
    type=str,
    default="configs/stage2-v2-snr.yaml",
    help="config for stage2",
)
parser.add_argument("--device", type=str, default="cuda")
args = parser.parse_args()

crm_path = hf_hub_download(repo_id="Zhengyi/CRM", filename="CRM.pth")
with open("configs/specs_objaverse_total.json") as f:
    specs = json.load(f)
model = CRM(specs)
model.load_state_dict(torch.load(crm_path, map_location="cpu"), strict=False)
model = model.to(args.device)

stage1_config = OmegaConf.load(args.stage1_config).config
stage2_config = OmegaConf.load(args.stage2_config).config
stage1_sampler_config = stage1_config.sampler
stage2_sampler_config = stage2_config.sampler
stage1_model_config = stage1_config.models
stage2_model_config = stage2_config.models

xyz_path = hf_hub_download(repo_id="Zhengyi/CRM", filename="ccm-diffusion.pth")
pixel_path = hf_hub_download(repo_id="Zhengyi/CRM", filename="pixel-diffusion.pth")
stage1_model_config.resume = pixel_path
stage2_model_config.resume = xyz_path

pipeline = TwoStagePipeline(
    stage1_model_config,
    stage2_model_config,
    stage1_sampler_config,
    stage2_sampler_config,
    device=args.device,
    dtype=torch.float32,
)

_DESCRIPTION = '''
* Our [official implementation](https://github.com/thu-ml/CRM) uses UV textures instead of vertex colors and produces better texture than this online demo.
* Project page of CRM: https://ml.cs.tsinghua.edu.cn/~zhengyi/CRM/
* If you find the output unsatisfying, try using different seeds :)
'''

with gr.Blocks() as demo:
    gr.Markdown("# CRM: Single Image to 3D Textured Mesh with Convolutional Reconstruction Model")
    gr.Markdown(_DESCRIPTION)
    with gr.Row():
        with gr.Column():
            with gr.Row():
                image_input = gr.Image(
                    label="Image input",
                    image_mode="RGBA",
                    sources="upload",
                    type="pil",
                )
                processed_image = gr.Image(label="Processed Image", interactive=False, type="pil", image_mode="RGB")
            with gr.Row():
                with gr.Column():
                    with gr.Row():
                        background_choice = gr.Radio(
                            [
                                "Alpha as mask",
                                "Auto Remove background",
                            ],
                            value="Auto Remove background",
                            label="background choice",
                        )
                        background_color = gr.ColorPicker(label="Background Color", value="#7F7F7F", interactive=False)
                        foreground_ratio = gr.Slider(
                            label="Foreground Ratio",
                            minimum=0.5,
                            maximum=1.0,
                            value=1.0,
                            step=0.05,
                        )
                with gr.Column():
                    seed = gr.Number(value=1234, label="seed", precision=0)
                    guidance_scale = gr.Number(value=5.5, minimum=3, maximum=10, label="guidance_scale")
                    step = gr.Number(value=30, minimum=30, maximum=100, label="sample steps", precision=0)
            text_button = gr.Button("Generate 3D shape")
            gr.Examples(
                examples=[os.path.join("examples", i) for i in os.listdir("examples")],
                inputs=[image_input],
                examples_per_page=20,
            )
        with gr.Column():
            image_output = gr.Image(interactive=False, label="Output RGB image")
            xyz_output = gr.Image(interactive=False, label="Output CCM image")
            output_model = gr.Model3D(
                label="Output OBJ",
                interactive=False,
            )
            gr.Markdown("Note: Ensure that the input image is correctly preprocessed onto a grey background; otherwise the results will be unpredictable.")
    inputs = [
        processed_image,
        seed,
        guidance_scale,
        step,
    ]
    outputs = [
        image_output,
        xyz_output,
        output_model,
    ]
    text_button.click(fn=check_input_image, inputs=[image_input]).success(
        fn=preprocess_image,
        inputs=[image_input, background_choice, foreground_ratio, background_color],
        outputs=[processed_image],
    ).success(
        fn=gen_image,
        inputs=inputs,
        outputs=outputs,
    )
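
# Launch the demo; run locally with `python app.py` (device defaults to "cuda").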
demo.queue().launch()