AIRider's picture
Update src/app.py
d5de94c verified
raw
history blame
6.02 kB
import tempfile
import time
from typing import Any
from collections.abc import Sequence
import gradio as gr
import numpy as np
import pillow_heif
import spaces
import torch
from gradio_image_annotation import image_annotator
from gradio_imageslider import ImageSlider
from PIL import Image
from pymatting.foreground.estimate_foreground_ml import estimate_foreground_ml
from refiners.fluxion.utils import no_grad
from refiners.solutions import BoxSegmenter
# Box as a 4-tuple of ints; process_bbox builds it in (xmin, ymin, xmax, ymax) order.
BoundingBox = tuple[int, int, int, int]
# Register pillow_heif's openers so PIL can open HEIF and AVIF files.
pillow_heif.register_heif_opener()
pillow_heif.register_avif_opener()
# Use CUDA when torch reports it available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Initialize segmenter
# The segmenter is constructed on CPU, then its device attribute and model are
# switched to `device`.
# NOTE(review): presumably done this way for Hugging Face ZeroGPU, where CUDA is
# only really usable inside @spaces.GPU calls — confirm.
segmenter = BoxSegmenter(device="cpu")
segmenter.device = device
segmenter.model = segmenter.model.to(device=segmenter.device)
def bbox_union(bboxes: Sequence[list[int]]) -> BoundingBox | None:
    """Return the smallest box enclosing every box in *bboxes*.

    Each box is a 4-element list of ints; the union takes the minimum of
    elements 0 and 1 and the maximum of elements 2 and 3 across all boxes.
    Returns ``None`` when *bboxes* is empty. Malformed boxes (length other
    than 4, or non-int coordinates) raise ``AssertionError`` when assertions
    are enabled.
    """
    if not bboxes:
        return None
    for box in bboxes:
        # Sanity checks only: these disappear under ``python -O``.
        assert len(box) == 4 and all(isinstance(coord, int) for coord in box)
    # Transpose into one list per coordinate position.
    left, top, right, bottom = [[box[axis] for box in bboxes] for axis in range(4)]
    return min(left), min(top), max(right), max(bottom)
def apply_mask(
    img: Image.Image,
    mask_img: Image.Image,
    defringe: bool = True,
) -> Image.Image:
    """Cut *img* out using *mask_img* as its alpha channel.

    Args:
        img: Source image; converted to RGB.
        mask_img: Mask with the same size as *img*; converted to 8-bit
            grayscale ("L") and used directly as the output alpha channel.
        defringe: When true, replace the RGB values with the foreground colors
            estimated by pymatting's ``estimate_foreground_ml`` from the image
            and the mask, to reduce background color bleeding at the edges.

    Returns:
        An RGBA image with straight (non-premultiplied) alpha: the RGB bands
        hold the (optionally defringed) colors and the A band holds the mask.

    Raises:
        AssertionError: if the two images differ in size (only while
            assertions are enabled).
    """
    assert img.size == mask_img.size
    img = img.convert("RGB")
    mask_img = mask_img.convert("L")
    if defringe:
        # Mitigate edge halo effects via color decontamination
        rgb, alpha = np.asarray(img) / 255.0, np.asarray(mask_img) / 255.0
        foreground = estimate_foreground_ml(rgb, alpha)
        # Round (rather than truncate) and clip before the uint8 cast, so float
        # error cannot shave a level off a color or wrap out-of-range values.
        foreground_u8 = np.clip(np.rint(np.asarray(foreground) * 255.0), 0, 255).astype(np.uint8)
        img = Image.fromarray(foreground_u8)
    # Attach the mask as the alpha band. paste()-ing onto a transparent black
    # canvas with an "L" mask would also blend the RGB bands toward black,
    # darkening semi-transparent edge pixels and undoing the defringing.
    result = img.convert("RGBA")
    result.putalpha(mask_img)
    return result
@spaces.GPU
def _gpu_process(
    img: Image.Image,
    bbox: BoundingBox | None,
) -> tuple[Image.Image, BoundingBox | None, list[str]]:
    """Run the box segmenter on *img* and time it.

    This is the only step wrapped in ``@spaces.GPU``.
    NOTE(review): presumably so a GPU is allocated just for segmentation on
    Hugging Face Spaces — confirm.

    Args:
        img: Image to segment.
        bbox: Box passed straight to the segmenter, or ``None``.

    Returns:
        ``(mask, bbox, time_log)``: the segmenter output, *bbox* unchanged,
        and a one-entry list recording the segmentation time in seconds.
    """
    start = time.time()
    mask = segmenter(img, bbox)
    timings = [f"segment: {time.time() - start}"]
    return mask, bbox, timings
def _process(
    img: Image.Image,
    bbox: BoundingBox | None,
) -> tuple[tuple[Image.Image, Image.Image], gr.DownloadButton]:
    """Segment the boxed object, cut it out, and build the UI outputs.

    Args:
        img: Input image. It is never modified in place; when either side
            exceeds 2048 px, a downscaled copy is processed (and returned).
        bbox: Box in *img* pixel coordinates, or ``None``. When a downscaled
            copy is made, a tuple box is rescaled by the same factor.

    Returns:
        ``((img, masked_rgb), download_button)``: the (possibly downscaled)
        input and the cutout composited over white, for the before/after
        slider, plus a download button pointing at a PNG of the transparent
        cutout cropped to the mask.
    """
    max_side = 2048
    # enforce max dimensions for pymatting performance reasons
    if img.width > max_side or img.height > max_side:
        orig_res = max(img.width, img.height)
        # thumbnail() resizes in place; work on a copy so the caller's image
        # is left untouched.
        img = img.copy()
        img.thumbnail((max_side, max_side))
        if isinstance(bbox, tuple):
            x0, y0, x1, y1 = (int(x * max_side / orig_res) for x in bbox)
            bbox = (x0, y0, x1, y1)

    # _gpu_process echoes the box back; it is not needed here.
    mask, _, time_log = _gpu_process(img, bbox)
    t0 = time.time()
    masked_alpha = apply_mask(img, mask, defringe=True)
    time_log.append(f"crop: {time.time() - t0}")
    print(", ".join(time_log))

    masked_rgb = Image.alpha_composite(Image.new("RGBA", masked_alpha.size, "white"), masked_alpha)

    # Crop the download to pixels whose mask value exceeds 10. For an empty
    # mask getbbox() returns None, and crop(None) keeps the whole image.
    thresholded = mask.point(lambda p: 255 if p > 10 else 0)
    crop_box = thresholded.getbbox()
    to_dl = masked_alpha.crop(crop_box)

    # delete=False keeps the file after it is closed so the download button
    # can serve it; the with-block closes the handle even if save() raises.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as temp:
        to_dl.save(temp, format="PNG")
    return (img, masked_rgb), gr.DownloadButton(value=temp.name, interactive=True)
def process_bbox(prompts: dict[str, Any]) -> tuple[tuple[Image.Image, Image.Image], gr.DownloadButton]:
    """Validate the annotator payload and run the cut-out pipeline.

    Args:
        prompts: Value of the ``image_annotator`` component. It must hold a PIL
            image under ``"image"`` and a list of zero or one box dicts (keys
            ``xmin``, ``ymin``, ``xmax``, ``ymax``) under ``"boxes"``.

    Returns:
        Whatever ``_process`` returns for the image and the box (or ``None``
        when no box was drawn).

    Raises:
        TypeError: if the image, the box list, or the box has the wrong type.
        ValueError: if more than one box was drawn.
    """
    # Plain assignments (not walrus-inside-assert): asserts are stripped under
    # ``python -O``, which previously left these names unbound.
    img = prompts["image"]
    boxes = prompts["boxes"]
    if not isinstance(img, Image.Image):
        raise TypeError(f"expected a PIL image, got {type(img).__name__}")
    if not isinstance(boxes, list):
        raise TypeError(f"expected a list of boxes, got {type(boxes).__name__}")
    if len(boxes) > 1:
        raise ValueError(f"expected at most one box, got {len(boxes)}")

    bbox = None
    if boxes:
        box = boxes[0]
        if not isinstance(box, dict):
            raise TypeError(f"expected a box dict, got {type(box).__name__}")
        bbox = tuple(box[k] for k in ("xmin", "ymin", "xmax", "ymax"))
    return _process(img, bbox)
def on_change_bbox(prompts: dict[str, Any] | None):
    """Toggle the cut-out button whenever the annotator value changes.

    The button is made interactive exactly when *prompts* is not ``None``.
    """
    enabled = prompts is not None
    return gr.update(interactive=enabled)
# HTML header rendered at the top of the page via gr.HTML(TITLE).
TITLE = """
<center>
<h1 style="font-size: 1.5rem; margin-bottom: 0.5rem;">
Object Cutter With Bounding Box
</h1>
<p>
Create high-quality HD cutouts for any object in your image using bounding box selection.
<br>
The object will be available on a transparent background, ready to paste elsewhere.
</p>
<p>
This space uses the
<a
href="https://huggingface.co/finegrain/finegrain-box-segmenter"
target="_blank"
>Finegrain Box Segmenter model</a>,
trained with a mix of natural data curated by Finegrain and
<a
href="https://huggingface.co/datasets/Nfiniteai/product-masks-sample"
target="_blank"
>synthetic data provided by Nfinite</a>.
</p>
</center>
"""
# Build the UI. On the left: a single-box image annotator and the cut-out
# button. On the right: a before/after slider and a download button. Example
# inputs are listed below the row.
with gr.Blocks() as demo:
    gr.HTML(TITLE)
    with gr.Row():
        with gr.Column():
            annotator = image_annotator(
                image_type="pil",
                disable_edit_boxes=True,
                show_download_button=False,
                show_share_button=False,
                single_box=True,
                label="Input",
            )
            # Starts disabled; on_change_bbox toggles it when the annotator changes.
            btn = gr.ClearButton(value="Cut Out Object", interactive=False)
        with gr.Column():
            oimg = ImageSlider(label="Before / After", show_download_button=False)
            dlbt = gr.DownloadButton("Download Cutout", interactive=False)

    # Register the slider as a component this ClearButton clears when clicked
    # (presumably so the previous result is wiped before the new one arrives).
    btn.add(oimg)

    annotator.change(
        fn=on_change_bbox,
        inputs=[annotator],
        outputs=[btn],
    )
    btn.click(
        fn=process_bbox,
        inputs=[annotator],
        outputs=[oimg, dlbt],
    )

    # Example annotator values; box coordinates are assumed to be in the
    # examples' own pixel space — TODO confirm against the image files.
    examples = [
        {
            "image": "examples/potted-plant.jpg",
            "boxes": [{"xmin": 51, "ymin": 511, "xmax": 639, "ymax": 1255}],
        },
        {
            "image": "examples/chair.jpg",
            "boxes": [{"xmin": 98, "ymin": 330, "xmax": 973, "ymax": 1468}],
        },
        {
            "image": "examples/black-lamp.jpg",
            "boxes": [{"xmin": 88, "ymin": 148, "xmax": 700, "ymax": 1414}],
        },
    ]

    # cache_examples=True: Gradio precomputes process_bbox outputs for the
    # examples rather than running them on each click.
    ex = gr.Examples(
        examples=examples,
        inputs=[annotator],
        outputs=[oimg, dlbt],
        fn=process_bbox,
        cache_examples=True,
    )

demo.launch(share=False)