import gradio as gr
import os
from PIL import Image
import torch
from diffusers.utils import check_min_version
from pipeline_objectclear import ObjectClearPipeline
from tools.download_util import load_file_from_url
from tools.painter import mask_painter
import argparse
import base64
import numpy as np
import torchvision.transforms.functional as TF
from scipy.ndimage import convolve, zoom
import spaces
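# `spaces` provides the @spaces.GPU decorator used below to request ZeroGPU time on Hugging Face Spaces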
from utils import resize_by_short_side
from tools.interact_tools import SamControler
from tools.misc import get_device
import json
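# ObjectClearPipeline relies on diffusers features available from 0.30.2 onward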
check_min_version("0.30.2")
def parse_arguments():
parser = argparse.ArgumentParser()
parser.add_argument('--device', type=str, default=None)
parser.add_argument('--sam_model_type', type=str, default="vit_h")
parser.add_argument('--port', type=int, default=8000, help="only useful when running gradio applications")
args = parser.parse_args()
if not args.device:
args.device = str(get_device())
return args
# convert points input to prompt state
def get_prompt(click_state, click_input):
inputs = json.loads(click_input)
points = click_state[0]
labels = click_state[1]
    for point in inputs:
        points.append(point[:2])
        labels.append(point[2])
click_state[0] = points
click_state[1] = labels
    prompt = {
        "prompt_type": ["click"],
        "input_point": click_state[0],
        "input_label": click_state[1],
        "multimask_output": True,
    }
return prompt
# use SAM to segment the object at the clicked point
@spaces.GPU
def sam_refine(image_state, point_prompt, click_state, evt: gr.SelectData):
if point_prompt == "Positive":
coordinate = "[[{},{},1]]".format(evt.index[0], evt.index[1])
else:
coordinate = "[[{},{},0]]".format(evt.index[0], evt.index[1])
# prompt for sam model
model.samcontroler.sam_controler.reset_image()
model.samcontroler.sam_controler.set_image(image_state["origin_image"])
prompt = get_prompt(click_state=click_state, click_input=coordinate)
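    # run SAM on all clicks accumulated so far; returns the mask, its logits, and an overlay preview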
mask, logit, painted_image = model.first_frame_click(
image=image_state["origin_image"],
points=np.array(prompt["input_point"]),
labels=np.array(prompt["input_label"]),
multimask=prompt["multimask_output"],
)
image_state["mask"] = mask
image_state["logit"] = logit
image_state["painted_image"] = painted_image
return painted_image, image_state, click_state
def add_multi_mask(image_state, interactive_state, mask_dropdown):
mask = image_state["mask"]
interactive_state["masks"].append(mask)
interactive_state["mask_names"].append("mask_{:03d}".format(len(interactive_state["masks"])))
mask_dropdown.append("mask_{:03d}".format(len(interactive_state["masks"])))
select_frame = show_mask(image_state, interactive_state, mask_dropdown)
return interactive_state, gr.update(choices=interactive_state["mask_names"], value=mask_dropdown), select_frame, [[],[]]
def clear_click(image_state, click_state):
click_state = [[],[]]
input_image = image_state["origin_image"]
return input_image, click_state
def remove_multi_mask(interactive_state, click_state, image_state):
interactive_state["mask_names"]= []
interactive_state["masks"] = []
click_state = [[],[]]
input_image = image_state["origin_image"]
return interactive_state, gr.update(choices=[],value=[]), input_image, click_state
def show_mask(image_state, interactive_state, mask_dropdown):
    mask_dropdown.sort()
    select_frame = image_state["origin_image"]
    if select_frame is not None:
        for i in range(len(mask_dropdown)):
            mask_number = int(mask_dropdown[i].split("_")[1]) - 1
            mask = interactive_state["masks"][mask_number]
            select_frame = mask_painter(select_frame, mask.astype('uint8'), mask_color=mask_number + 2)
    return select_frame
@spaces.GPU
def upload_and_reset(image_input, interactive_state):
click_state = [[], []]
interactive_state["mask_names"]= []
interactive_state["masks"] = []
image_state, image_info, image_input = update_image_state_on_upload(image_input)
return (
image_state,
image_info,
image_input,
interactive_state,
click_state,
gr.update(choices=[], value=[]),
)
def update_image_state_on_upload(image_input):
frame = image_input
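    # PIL reports size as (width, height); store it as (height, width) for NumPy-style indexing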
image_size = (frame.size[1], frame.size[0])
frame_np = np.array(frame)
image_state = {
"origin_image": frame_np,
"painted_image": frame_np.copy(),
"mask": np.zeros((image_size[0], image_size[1]), np.uint8),
"logit": None,
}
image_info = f"Image Name: uploaded.png,\nImage Size: {image_size}"
model.samcontroler.sam_controler.reset_image()
model.samcontroler.sam_controler.set_image(frame_np)
return image_state, image_info, image_input
# SAM-based mask generator
class MaskGenerator:
def __init__(self, sam_checkpoint, args):
self.args = args
self.samcontroler = SamControler(sam_checkpoint, args.sam_model_type, args.device)
    def first_frame_click(self, image: np.ndarray, points: np.ndarray, labels: np.ndarray, multimask=True):
mask, logit, painted_image = self.samcontroler.first_frame_click(image, points, labels, multimask)
return mask, logit, painted_image
# parse command-line arguments (device, SAM backbone, port)
args = parse_arguments()
sam_checkpoint_url_dict = {
'vit_h': "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
'vit_l': "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth",
'vit_b': "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth"
}
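# download the selected SAM checkpoint into the Space's pretrained_models folder
# (load_file_from_url is expected to reuse a previously downloaded copy if one exists)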
checkpoint_folder = os.path.join('/home/user/app/', 'pretrained_models')
sam_checkpoint = load_file_from_url(sam_checkpoint_url_dict[args.sam_model_type], checkpoint_folder)
# initialize the SAM mask generator
model = MaskGenerator(sam_checkpoint, args)
# Build pipeline
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
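# load ObjectClear weights in fp16 to cut memory; apply_attention_guided_fusion enables the
# model's attention-guided fusion that blends the edited region back into the original background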
pipe = ObjectClearPipeline.from_pretrained_with_custom_modules(
"jixin0101/ObjectClear",
torch_dtype=torch.float16,
variant='fp16',
apply_attention_guided_fusion=True
)
pipe.to(device)
@spaces.GPU
def process(image_state, interactive_state, mask_dropdown, guidance_scale, seed, num_inference_steps, strength):
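    # a fixed-seed generator makes the removal result reproducible for a given seed value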
    generator = torch.Generator(device=device).manual_seed(seed)
image_np = image_state["origin_image"]
image = Image.fromarray(image_np)
if interactive_state["masks"]:
if len(mask_dropdown) == 0:
mask_dropdown = ["mask_001"]
mask_dropdown.sort()
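        # merge every selected mask into one label map, giving each mask its own integer id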
        first_idx = int(mask_dropdown[0].split("_")[1]) - 1
        template_mask = interactive_state["masks"][first_idx] * (first_idx + 1)
        for i in range(1, len(mask_dropdown)):
            mask_number = int(mask_dropdown[i].split("_")[1]) - 1
            template_mask = np.clip(template_mask + interactive_state["masks"][mask_number] * (mask_number + 1), 0, mask_number + 1)
        image_state["mask"] = template_mask
else:
template_mask = image_state["mask"]
    mask = Image.fromarray((template_mask > 0).astype(np.uint8) * 255)
image_or = image.copy()
image = image.convert("RGB")
mask = mask.convert("RGB")
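    # resize so the shorter side is 512 px to match the training resolution (see the note in the
    # demo description); outputs are scaled back to the original size before being returned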
image = resize_by_short_side(image, 512, resample=Image.BICUBIC)
mask = resize_by_short_side(mask, 512, resample=Image.NEAREST)
w, h = image.size
result = pipe(
prompt="remove the instance of object",
image=image,
mask_image=mask,
generator=generator,
num_inference_steps=num_inference_steps,
strength=strength,
guidance_scale=guidance_scale,
height=h,
width=w,
)
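    # result.images[0] is the fused output at the resized (shorter side = 512) resolution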
    fused_img_pil = result.images[0].resize(image_or.size)
    return fused_img_pil, (image.resize(image_or.size), fused_img_pil)
with open("./Logo.png", "rb") as f:
img_bytes = f.read()
img_b64 = base64.b64encode(img_bytes).decode()
html_img = f'''
<div style="display:flex; justify-content:center; align-items:center; width:100%;">
<img src="data:image/png;base64,{img_b64}" style="border:none; width:200px; height:auto;"/>
</div>
'''
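# fetch the tutorial video shown in the "Video Tutorial" accordion below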
tutorial_url = "https://github.com/zjx0101/ObjectClear/releases/download/media/tutorial.mp4"
assets_path = os.path.join('/home/user/app/hugging_face/', "assets/")
load_file_from_url(tutorial_url, assets_path)
description = r"""
<b>Official Gradio demo</b> for <a href='https://github.com/zjx0101/ObjectClear' target='_blank'><b>ObjectClear: Complete Object Removal via Object-Effect Attention</b></a>.<br>
🔥 ObjectClear is an object removal model that jointly eliminates the target object and its associated effects (e.g., shadows and reflections) by leveraging Object-Effect Attention, while preserving background consistency.<br>
🖼️ Drop in your image, select the target masks with a few clicks, and get the object removal result!<br>
*Note: All input images are temporarily resized (shorter side = 512 pixels) during inference to match the training resolution. Final outputs are restored to the original resolution.*<br>
"""
article = r"""<h3>
<b>If ObjectClear is helpful, please help star the <a href='https://github.com/zjx0101/ObjectClear' target='_blank'>GitHub Repo</a>. Thanks!</b></h3>
<hr>
📑 **Citation**
<br>
If our work is useful for your research, please consider citing:
```bibtex
@InProceedings{zhao2025ObjectClear,
title = {{ObjectClear}: Complete Object Removal via Object-Effect Attention},
author = {Zhao, Jixin and Zhou, Shangchen and Wang, Zhouxia and Yang, Peiqing and Loy, Chen Change},
booktitle = {arXiv preprint arXiv:2505.22636},
year = {2025}
}
```
📧 **Contact**
<br>
If you have any questions, please feel free to reach out to me at <b>[email protected]</b>.
<br>
👏 **Acknowledgement**
<br>
This demo is adapted from [MatAnyone](https://github.com/pq-yang/MatAnyone) and leverages segmentation capabilities from [Segment Anything](https://github.com/facebookresearch/segment-anything). Thanks for their awesome work!
"""
custom_css = """
#input-image {
aspect-ratio: 1 / 1;
width: 100%;
max-width: 100%;
height: auto;
display: flex;
align-items: center;
justify-content: center;
}
#input-image img {
max-width: 100%;
max-height: 100%;
object-fit: contain;
display: block;
}
#main-columns {
gap: 60px;
}
#main-columns > .gr-column {
flex: 1;
}
#compare-image {
width: 100%;
aspect-ratio: 1 / 1;
display: flex;
align-items: center;
justify-content: center;
margin: 0;
padding: 0;
max-width: 100%;
box-sizing: border-box;
}
#compare-image svg.svelte-zyxd38 {
position: absolute !important;
top: 50% !important;
left: 50% !important;
transform: translate(-50%, -50%) !important;
}
#compare-image .icon.svelte-1oiin9d {
position: absolute;
top: 50%;
left: 50%;
transform: translate(-50%, -50%);
}
#compare-image {
position: relative;
overflow: hidden;
}
.new_button {background-color: #171717 !important; color: #ffffff !important; border: none !important;}
.new_button:hover {background-color: #4b4b4b !important;}
#start-button {
background: linear-gradient(135deg, #2575fc 0%, #6a11cb 100%);
color: white;
border: none;
padding: 12px 24px;
font-size: 16px;
font-weight: bold;
border-radius: 12px;
cursor: pointer;
box-shadow: 0 0 12px rgba(100, 100, 255, 0.7);
transition: all 0.3s ease;
}
#start-button:hover {
transform: scale(1.05);
box-shadow: 0 0 20px rgba(100, 100, 255, 1);
}
.button-wrapper {
width: 30%;
text-align: center;
}
.wide-button {
width: 83% !important;
background-color: black !important;
color: white !important;
border: none !important;
padding: 8px 0 !important;
font-size: 16px !important;
display: inline-block;
margin: 30px 0px 0px 50px ;
}
.wide-button:hover {
background-color: #656262 !important;
}
"""
with gr.Blocks(css=custom_css) as demo:
gr.HTML(html_img)
gr.Markdown(description)
with gr.Group(elem_classes="gr-monochrome-group", visible=True):
with gr.Row():
with gr.Accordion('SAM Settings (click to expand)', open=False):
with gr.Row():
point_prompt = gr.Radio(
choices=["Positive", "Negative"],
value="Positive",
label="Point Prompt",
info="Click to add positive or negative point for target mask",
interactive=True,
min_width=100,
scale=1)
                    mask_dropdown = gr.Dropdown(multiselect=True, value=[], label="Mask Selection", info="Choose one or more of the masks added via 'Add Mask'")
with gr.Row(elem_id="main-columns"):
with gr.Column():
click_state = gr.State([[],[]])
interactive_state = gr.State(
{
"mask_names": [],
"masks": []
}
)
image_state = gr.State(
{
"origin_image": None,
"painted_image": None,
"mask": None,
"logit": None
}
)
image_info = gr.Textbox(label="Image Info", visible=False)
input_image = gr.Image(
label='Input',
type='pil',
sources=["upload"],
image_mode='RGB',
interactive=True,
elem_id="input-image"
)
with gr.Row(equal_height=True, elem_classes="mask_button_group"):
                clear_button_click = gr.Button(value="Clear Clicks", elem_classes="new_button", min_width=100)
add_mask_button = gr.Button(value="Add Mask", elem_classes="new_button", min_width=100)
remove_mask_button = gr.Button(value="Delete Mask", elem_classes="new_button", min_width=100)
submit_button_component = gr.Button(
value='Start ObjectClear', elem_id="start-button"
)
with gr.Accordion('ObjectClear Settings', open=True):
strength = gr.Radio(
choices=[0.99, 1.0],
value=0.99,
label="Strength",
info="0.99 better preserves the background and color; use 1.0 if object/shadow is not fully removed (default: 0.99)"
)
guidance_scale = gr.Slider(
minimum=1, maximum=10, step=0.5, value=2.5,
label="Guidance Scale",
info="Higher = stronger removal; lower = better background preservation (default: 2.5)"
)
seed = gr.Slider(
minimum=0, maximum=1000000, step=1, value=300000,
label="Seed Value",
info="Different seeds can lead to noticeably different object removal results (default: 300000)"
)
num_inference_steps = gr.Slider(
minimum=1, maximum=40, step=1, value=20,
label="Num Inference Steps",
info="Higher values may improve quality but take longer (default: 20)"
)
with gr.Column():
output_image_component = gr.Image(
type='pil', image_mode='RGB', label='Output', format="png", elem_id="input-image")
output_compare_image_component = gr.ImageSlider(
label="Comparison",
type="pil",
format='png',
elem_id="compare-image"
)
input_image.upload(
fn=upload_and_reset,
inputs=[input_image, interactive_state],
outputs=[
image_state,
image_info,
input_image,
interactive_state,
click_state,
mask_dropdown,
]
)
# click select image to get mask using sam
input_image.select(
fn=sam_refine,
inputs=[image_state, point_prompt, click_state],
outputs=[input_image, image_state, click_state]
)
# add different mask
add_mask_button.click(
fn=add_multi_mask,
inputs=[image_state, interactive_state, mask_dropdown],
outputs=[interactive_state, mask_dropdown, input_image, click_state]
)
remove_mask_button.click(
fn=remove_multi_mask,
inputs=[interactive_state, click_state, image_state],
outputs=[interactive_state, mask_dropdown, input_image, click_state]
)
# points clear
clear_button_click.click(
        fn=clear_click,
        inputs=[image_state, click_state],
        outputs=[input_image, click_state],
)
submit_button_component.click(
fn=process,
inputs=[
image_state,
interactive_state,
mask_dropdown,
guidance_scale,
seed,
num_inference_steps,
strength
],
outputs=[
output_image_component, output_compare_image_component
]
)
with gr.Accordion("📕 Video Tutorial (click to expand)", open=False, elem_classes="custom-bg"):
with gr.Row():
gr.Video(value="/home/user/app/hugging_face/assets/tutorial.mp4", elem_classes="video")
gr.Markdown("---")
gr.Markdown("## Examples")
example_images = [
os.path.join(os.path.dirname(__file__), "examples", f"test{i}.png")
for i in range(10)
]
    examples_data = [[path, None] for path in example_images]
examples = gr.Examples(
examples=examples_data,
inputs=[input_image, interactive_state],
outputs=[image_state, image_info, input_image,
interactive_state, click_state, mask_dropdown],
fn=upload_and_reset,
run_on_click=True,
cache_examples=False,
label="Click below to load example images"
)
gr.Markdown(article)
def pre_update_input_image():
return gr.update(value=None)
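    # clear any previously loaded image when the page (re)loads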
demo.load(
fn=pre_update_input_image,
inputs=[],
outputs=[input_image]
)
demo.launch(debug=True, show_error=True)