import gradio as gr
import os
from PIL import Image
import torch
from diffusers.utils import load_image, check_min_version
from pipeline_objectclear import ObjectClearPipeline
from tools.download_util import load_file_from_url
from tools.painter import mask_painter
import argparse
from safetensors.torch import load_file
from model import CLIPImageEncoder, PostfuseModule
import numpy as np
import torchvision.transforms.functional as TF
from scipy.ndimage import convolve, zoom
import cv2
import time
from huggingface_hub import hf_hub_download
import spaces
from tools.interact_tools import SamControler
from tools.misc import get_device
import json

check_min_version("0.30.2")


def parse_augment():
    parser = argparse.ArgumentParser()
    parser.add_argument('--device', type=str, default=None)
    parser.add_argument('--sam_model_type', type=str, default="vit_h")
    parser.add_argument('--port', type=int, default=8000, help="only useful when running gradio applications")
    args = parser.parse_args()
    if not args.device:
        args.device = str(get_device())
    return args


def pad_to_multiple(image: np.ndarray, multiple: int = 8):
    h, w = image.shape[:2]
    pad_h = (multiple - h % multiple) % multiple
    pad_w = (multiple - w % multiple) % multiple
    if image.ndim == 3:
        padded = np.pad(image, ((0, pad_h), (0, pad_w), (0, 0)), mode='reflect')
    else:
        padded = np.pad(image, ((0, pad_h), (0, pad_w)), mode='reflect')
    return padded, h, w


def crop_to_original(image: np.ndarray, h: int, w: int):
    return image[:h, :w]


def wavelet_blur_np(image: np.ndarray, radius: int):
    kernel = np.array([
        [0.0625, 0.125, 0.0625],
        [0.125, 0.25, 0.125],
        [0.0625, 0.125, 0.0625]
    ], dtype=np.float32)
    blurred = np.empty_like(image)
    for c in range(image.shape[0]):
        blurred_c = convolve(image[c], kernel, mode='nearest')
        if radius > 1:
            blurred_c = zoom(zoom(blurred_c, 1 / radius, order=1), radius, order=1)
        blurred[c] = blurred_c
    return blurred


def wavelet_decomposition_np(image: np.ndarray, levels=5):
    high_freq = np.zeros_like(image)
    for i in range(levels):
        radius = 2 ** i
        low_freq = wavelet_blur_np(image, radius)
        high_freq += (image - low_freq)
        image = low_freq
    return high_freq, low_freq


def wavelet_reconstruction_np(content_feat: np.ndarray, style_feat: np.ndarray):
    content_high, _ = wavelet_decomposition_np(content_feat)
    _, style_low = wavelet_decomposition_np(style_feat)
    return content_high + style_low


def wavelet_color_fix_np(fused: np.ndarray, mask: np.ndarray) -> np.ndarray:
    fused_np = fused.astype(np.float32) / 255.0
    mask_np = mask.astype(np.float32) / 255.0
    fused_np = fused_np.transpose(2, 0, 1)
    mask_np = mask_np.transpose(2, 0, 1)
    result_np = wavelet_reconstruction_np(fused_np, mask_np)
    result_np = result_np.transpose(1, 2, 0)
    result_np = np.clip(result_np * 255.0, 0, 255).astype(np.uint8)
    return result_np


def fuse_with_wavelet(ori: np.ndarray, removed: np.ndarray, attn_map: np.ndarray, multiple: int = 8):
    H, W = ori.shape[:2]
    attn_map = attn_map.astype(np.float32)
    _, attn_map = cv2.threshold(attn_map, 128, 255, cv2.THRESH_BINARY)
    am = attn_map.astype(np.float32)
    am = am / 255.0
    am_up = cv2.resize(am, (W, H), interpolation=cv2.INTER_NEAREST)
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (21, 21))
    am_d = cv2.dilate(am_up, kernel, iterations=1)
    am_d = cv2.GaussianBlur(am_d.astype(np.float32), (9, 9), sigmaX=2)
    am_merged = np.maximum(am_up, am_d)
    am_merged = np.clip(am_merged, 0, 1)
    attn_up_3c = np.stack([am_merged] * 3, axis=-1)
    attn_up_ori_3c = np.stack([am_up] * 3, axis=-1)
    ori_out = ori * (1 - attn_up_ori_3c)
    rem_out = removed * (1 - attn_up_ori_3c)
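    # Outside the attended region, transfer the low-frequency color/illumination of the
    # inpainted result onto the original image: wavelet_color_fix_np keeps the original's
    # high frequencies and the inpainted result's low frequencies, so the untouched
    # background retains its detail while matching the diffusion output's global tone.
    # Padding to a multiple keeps the zoom-based blur in wavelet_blur_np shape-consistent.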
    ori_pad, h0, w0 = pad_to_multiple(ori_out, multiple)
    rem_pad, _, _ = pad_to_multiple(rem_out, multiple)
    wave_rgb = wavelet_color_fix_np(ori_pad, rem_pad)
    wave = crop_to_original(wave_rgb, h0, w0)
    # fusion
    fused = (wave * (1 - attn_up_3c) + removed * attn_up_3c).astype(np.uint8)
    return fused


def resize_by_short_side(image, target_short=512, resample=Image.BICUBIC):
    w, h = image.size
    if w < h:
        new_w = target_short
        new_h = int(h * target_short / w)
        new_h = (new_h + 15) // 16 * 16
    else:
        new_h = target_short
        new_w = int(w * target_short / h)
        new_w = (new_w + 15) // 16 * 16
    return image.resize((new_w, new_h), resample=resample)


# convert points input to prompt state
def get_prompt(click_state, click_input):
    inputs = json.loads(click_input)
    points = click_state[0]
    labels = click_state[1]
    for input in inputs:
        points.append(input[:2])
        labels.append(input[2])
    click_state[0] = points
    click_state[1] = labels
    prompt = {
        "prompt_type": ["click"],
        "input_point": click_state[0],
        "input_label": click_state[1],
        "multimask_output": "True",
    }
    return prompt


# use sam to get the mask
@spaces.GPU
def sam_refine(image_state, point_prompt, click_state, evt: gr.SelectData):
    if point_prompt == "Positive":
        coordinate = "[[{},{},1]]".format(evt.index[0], evt.index[1])
    else:
        coordinate = "[[{},{},0]]".format(evt.index[0], evt.index[1])

    # prompt for sam model
    model.samcontroler.sam_controler.reset_image()
    model.samcontroler.sam_controler.set_image(image_state["origin_image"])
    prompt = get_prompt(click_state=click_state, click_input=coordinate)

    mask, logit, painted_image = model.first_frame_click(
        image=image_state["origin_image"],
        points=np.array(prompt["input_point"]),
        labels=np.array(prompt["input_label"]),
        multimask=prompt["multimask_output"],
    )
    image_state["mask"] = mask
    image_state["logit"] = logit
    image_state["painted_image"] = painted_image

    return painted_image, image_state, click_state


def add_multi_mask(image_state, interactive_state, mask_dropdown):
    mask = image_state["mask"]
    interactive_state["masks"].append(mask)
    interactive_state["mask_names"].append("mask_{:03d}".format(len(interactive_state["masks"])))
    mask_dropdown.append("mask_{:03d}".format(len(interactive_state["masks"])))
    select_frame = show_mask(image_state, interactive_state, mask_dropdown)
    return interactive_state, gr.update(choices=interactive_state["mask_names"], value=mask_dropdown), select_frame, [[], []]


def clear_click(image_state, click_state):
    click_state = [[], []]
    input_image = image_state["origin_image"]
    return input_image, click_state


def remove_multi_mask(interactive_state, click_state, image_state):
    interactive_state["mask_names"] = []
    interactive_state["masks"] = []
    click_state = [[], []]
    input_image = image_state["origin_image"]
    return interactive_state, gr.update(choices=[], value=[]), input_image, click_state


def show_mask(image_state, interactive_state, mask_dropdown):
    mask_dropdown.sort()
    if image_state["origin_image"] is not None:
        select_frame = image_state["origin_image"]
    for i in range(len(mask_dropdown)):
        mask_number = int(mask_dropdown[i].split("_")[1]) - 1
        mask = interactive_state["masks"][mask_number]
        select_frame = mask_painter(select_frame, mask.astype('uint8'), mask_color=mask_number + 2)
    return select_frame


@spaces.GPU
def upload_and_reset(image_input, interactive_state):
    click_state = [[], []]
    interactive_state["mask_names"] = []
    interactive_state["masks"] = []
    image_state, image_info, image_input = update_image_state_on_upload(image_input)
    return (
        image_state,
        image_info,
        image_input,
        interactive_state,
        click_state,
        gr.update(choices=[], value=[]),
    )


def update_image_state_on_upload(image_input):
    frame = image_input
    image_size = (frame.size[1], frame.size[0])
    frame_np = np.array(frame)
    image_state = {
        "origin_image": frame_np,
        "painted_image": frame_np.copy(),
        "mask": np.zeros((image_size[0], image_size[1]), np.uint8),
        "logit": None,
    }
    image_info = f"Image Name: uploaded.png,\nImage Size: {image_size}"
    model.samcontroler.sam_controler.reset_image()
    model.samcontroler.sam_controler.set_image(frame_np)
    return image_state, image_info, image_input


# SAM generator
class MaskGenerator():
    def __init__(self, sam_checkpoint, args):
        self.args = args
        self.samcontroler = SamControler(sam_checkpoint, args.sam_model_type, args.device)

    def first_frame_click(self, image: np.ndarray, points: np.ndarray, labels: np.ndarray, multimask=True):
        mask, logit, painted_image = self.samcontroler.first_frame_click(image, points, labels, multimask)
        return mask, logit, painted_image


# args, defined in track_anything.py
args = parse_augment()

sam_checkpoint_url_dict = {
    'vit_h': "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
    'vit_l': "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth",
    'vit_b': "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth"
}
checkpoint_folder = os.path.join('/home/user/app/', 'pretrained_models')
sam_checkpoint = load_file_from_url(sam_checkpoint_url_dict[args.sam_model_type], checkpoint_folder)

# initialize sams
model = MaskGenerator(sam_checkpoint, args)

# Build pipeline
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
pipe = ObjectClearPipeline.from_pretrained_with_custom_modules(
    "jixin0101/ObjectClear",
    torch_dtype=torch.float16,
    save_cross_attn=True,
    cache_dir="/home/jovyan/shared/jixinzhao/models",
)
pipe.to(device)


@spaces.GPU
def process(image_state,
            interactive_state,
            mask_dropdown,
            guidance_scale,
            seed,
            num_inference_steps,
            strength):
    generator = torch.Generator(device="cuda").manual_seed(seed)
    image_np = image_state["origin_image"]
    image = Image.fromarray(image_np)

    if interactive_state["masks"]:
        if len(mask_dropdown) == 0:
            mask_dropdown = ["mask_001"]
        mask_dropdown.sort()
        template_mask = interactive_state["masks"][int(mask_dropdown[0].split("_")[1]) - 1] * (int(mask_dropdown[0].split("_")[1]))
        for i in range(1, len(mask_dropdown)):
            mask_number = int(mask_dropdown[i].split("_")[1]) - 1
            template_mask = np.clip(template_mask + interactive_state["masks"][mask_number] * (mask_number + 1), 0, mask_number + 1)
        image_state["mask"] = template_mask
    else:
        template_mask = image_state["mask"]
    mask = Image.fromarray(template_mask.astype(np.uint8) * 255)

    image_or = image.copy()
    image = image.convert("RGB")
    mask = mask.convert("RGB")

    image = resize_by_short_side(image, 512, resample=Image.BICUBIC)
    mask = resize_by_short_side(mask, 512, resample=Image.NEAREST)
    w, h = image.size

    result = pipe(
        prompt="remove the instance of object",
        image=image,
        mask_image=mask,
        generator=generator,
        num_inference_steps=num_inference_steps,
        strength=strength,
        guidance_scale=guidance_scale,
        height=h,
        width=w,
    )
    inpainted_img = result[0].images[0]
    attn_map = result[1]
    attn_np = attn_map.mean(dim=1)[0].cpu().numpy() * 255.
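    # Fuse the diffusion output back into the input using the object-effect attention map:
    # attended pixels (the object and its effects) come from the inpainted result, while
    # everything else stays as the wavelet-color-corrected original to preserve background detail.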
    fused_img = fuse_with_wavelet(np.array(image), np.array(inpainted_img), attn_np)
    fused_img_pil = Image.fromarray(fused_img.astype(np.uint8))

    return fused_img_pil.resize(image_or.size), (image.resize(image_or.size), fused_img_pil.resize(image_or.size))


import base64

with open("./Logo.png", "rb") as f:
    img_bytes = f.read()
img_b64 = base64.b64encode(img_bytes).decode()

# Header logo markup (assumed minimal placeholder): centers the base64-encoded Logo.png.
html_img = f'''
<div style="display: flex; justify-content: center; align-items: center;">
    <img src="data:image/png;base64,{img_b64}" alt="ObjectClear logo" style="height: 90px;">
</div>
'''
tutorial_url = "https://github.com/zjx0101/ObjectClear/releases/download/media/tutorial.mp4"
assets_path = os.path.join('/home/user/app/hugging_face/', "assets/")
load_file_from_url(tutorial_url, assets_path)

description = r"""
Official Gradio demo for ObjectClear: Complete Object Removal via Object-Effect Attention.
🔥 ObjectClear is an object removal model that can jointly eliminate the target object and its associated effects by leveraging Object-Effect Attention, while preserving background consistency.
🖼️ Drop in your image, assign the target mask(s) with a few clicks, and get the object removal result!
*Note: Due to online GPU memory constraints, all input images will be resized during inference so that the shortest side is 512 pixels.*
"""

article = r"""

If ObjectClear is helpful, please help star the GitHub repo. Thanks!


📑 **Citation**
If our work is useful for your research, please consider citing:
```bibtex
@InProceedings{zhao2025ObjectClear,
    title     = {{ObjectClear}: Complete Object Removal via Object-Effect Attention},
    author    = {Zhao, Jixin and Zhou, Shangchen and Wang, Zhouxia and Yang, Peiqing and Loy, Chen Change},
    booktitle = {arXiv preprint arXiv:2505.22636},
    year      = {2025}
}
```

📧 **Contact**
If you have any questions, please feel free to reach out to me at jixinzhao0101@gmail.com.
👏 **Acknowledgement**
This demo is adapted from [MatAnyone](https://github.com/pq-yang/MatAnyone) and leverages segmentation capabilities from [Segment Anything](https://github.com/facebookresearch/segment-anything). Thanks for their awesome work!
"""

custom_css = """
#input-image {
    aspect-ratio: 1 / 1;
    width: 100%;
    max-width: 100%;
    height: auto;
    display: flex;
    align-items: center;
    justify-content: center;
}
#input-image img {
    max-width: 100%;
    max-height: 100%;
    object-fit: contain;
    display: block;
}
#main-columns {
    gap: 60px;
}
#main-columns > .gr-column {
    flex: 1;
}
#compare-image {
    width: 100%;
    aspect-ratio: 1 / 1;
    display: flex;
    align-items: center;
    justify-content: center;
    margin: 0;
    padding: 0;
    max-width: 100%;
    box-sizing: border-box;
}
#compare-image svg.svelte-zyxd38 {
    position: absolute !important;
    top: 50% !important;
    left: 50% !important;
    transform: translate(-50%, -50%) !important;
}
#compare-image .icon.svelte-1oiin9d {
    position: absolute;
    top: 50%;
    left: 50%;
    transform: translate(-50%, -50%);
}
#compare-image {
    position: relative;
    overflow: hidden;
}
.new_button {background-color: #171717 !important; color: #ffffff !important; border: none !important;}
.new_button:hover {background-color: #4b4b4b !important;}
#start-button {
    background: linear-gradient(135deg, #2575fc 0%, #6a11cb 100%);
    color: white;
    border: none;
    padding: 12px 24px;
    font-size: 16px;
    font-weight: bold;
    border-radius: 12px;
    cursor: pointer;
    box-shadow: 0 0 12px rgba(100, 100, 255, 0.7);
    transition: all 0.3s ease;
}
#start-button:hover {
    transform: scale(1.05);
    box-shadow: 0 0 20px rgba(100, 100, 255, 1);
}
"""

with gr.Blocks(css=custom_css) as demo:
    gr.HTML(html_img)
    gr.Markdown(description)

    with gr.Group(elem_classes="gr-monochrome-group", visible=True):
        with gr.Row():
            with gr.Accordion('SAM Settings (click to expand)', open=False):
                with gr.Row():
                    point_prompt = gr.Radio(
                        choices=["Positive", "Negative"],
                        value="Positive",
                        label="Point Prompt",
                        info="Click to add a positive or negative point for the target mask",
                        interactive=True,
                        min_width=100,
                        scale=1)
                    mask_dropdown = gr.Dropdown(multiselect=True, value=[], label="Mask Selection", info="Choose 1~all mask(s) added in Step 2")

    with gr.Row(elem_id="main-columns"):
        with gr.Column():
            click_state = gr.State([[], []])
            interactive_state = gr.State(
                {
                    "mask_names": [],
                    "masks": []
                }
            )
            image_state = gr.State(
                {
                    "origin_image": None,
                    "painted_image": None,
                    "mask": None,
                    "logit": None
                }
            )
            image_info = gr.Textbox(label="Image Info", visible=False)
            input_image = gr.Image(
                label='Input',
                type='pil',
                sources=["upload"],
                image_mode='RGB',
                interactive=True,
                elem_id="input-image"
            )
            with gr.Row(equal_height=True, elem_classes="mask_button_group"):
                clear_button_click = gr.Button(value="Clear Clicks", elem_classes="new_button", min_width=100)
                add_mask_button = gr.Button(value="Add Mask", elem_classes="new_button", min_width=100)
                remove_mask_button = gr.Button(value="Delete Mask", elem_classes="new_button", min_width=100)
            submit_button_component = gr.Button(
                value='Start ObjectClear', elem_id="start-button"
            )
            with gr.Accordion('ObjectClear Settings', open=True):
                strength = gr.Radio(
                    choices=[0.99, 1.0],
                    value=0.99,
                    label="Strength",
                    info="0.99 better preserves the background and color; use 1.0 if the object/shadow is not fully removed (default: 0.99)"
                )
                guidance_scale = gr.Slider(
                    minimum=1,
                    maximum=10,
                    step=0.5,
                    value=2.5,
                    label="Guidance Scale",
                    info="Higher = stronger removal; lower = better background preservation (default: 2.5)"
                )
                seed = gr.Slider(
                    minimum=0,
                    maximum=1000000,
                    step=1,
                    value=300000,
label="Seed Value", info="Different seeds can lead to noticeably different object removal results (default: 300000)" ) num_inference_steps = gr.Slider( minimum=1, maximum=40, step=1, value=20, label="Num Inference Steps", info="Higher values may improve quality but take longer (default: 20)" ) with gr.Column(): output_image_component = gr.Image( type='pil', image_mode='RGB', label='Output', format="png", elem_id="input-image") output_compare_image_component = gr.ImageSlider( label="Comparison", type="pil", format='png', elem_id="compare-image" ) input_image.upload( fn=upload_and_reset, inputs=[input_image, interactive_state], outputs=[ image_state, image_info, input_image, interactive_state, click_state, mask_dropdown, ] ) # click select image to get mask using sam input_image.select( fn=sam_refine, inputs=[image_state, point_prompt, click_state], outputs=[input_image, image_state, click_state] ) # add different mask add_mask_button.click( fn=add_multi_mask, inputs=[image_state, interactive_state, mask_dropdown], outputs=[interactive_state, mask_dropdown, input_image, click_state] ) remove_mask_button.click( fn=remove_multi_mask, inputs=[interactive_state, click_state, image_state], outputs=[interactive_state, mask_dropdown, input_image, click_state] ) # points clear clear_button_click.click( fn = clear_click, inputs = [image_state, click_state,], outputs = [input_image, click_state], ) submit_button_component.click( fn=process, inputs=[ image_state, interactive_state, mask_dropdown, guidance_scale, seed, num_inference_steps, strength ], outputs=[ output_image_component, output_compare_image_component ] ) with gr.Accordion("📕 Video Tutorial (click to expand)", open=False, elem_classes="custom-bg"): with gr.Row(): gr.Video(value="/home/user/app/hugging_face/assets/tutorial.mp4", elem_classes="video") gr.Markdown("---") gr.Markdown("## Examples") example_images = [ os.path.join(os.path.dirname(__file__), "examples", f"test{i}.png") for i in range(10) ] examples_data = [ [example_images[i], None] for i in range(len(example_images)) ] examples = gr.Examples( examples=examples_data, inputs=[input_image, interactive_state], outputs=[image_state, image_info, input_image, interactive_state, click_state, mask_dropdown], fn=upload_and_reset, run_on_click=True, cache_examples=False, label="Click below to load example images" ) gr.Markdown(article) def pre_update_input_image(): return gr.update(value=None) demo.load( fn=pre_update_input_image, inputs=[], outputs=[input_image] ) demo.launch(debug=True, show_error=True)