Spaces:
Running
Running
from typing import List | |
import gradio as gr | |
import numpy as np | |
import supervision as sv | |
import torch | |
from PIL import Image | |
from transformers import pipeline | |
# Global configuration
MARKDOWN = """
# SAM - Softly Activated Masks
"""

# Example rows: [image URL, text label, float].
# NOTE(review): the gr.Examples call in this file wires only the image input,
# so the label and float columns appear unused by the UI — confirm intent.
EXAMPLES = [
    [
        "https://media.roboflow.com/notebooks/examples/dog.jpeg",
        "dog",
        0.5,
    ],
    [
        "https://media.roboflow.com/notebooks/examples/dog.jpeg",
        "building",
        0.5,
    ],
    [
        "https://media.roboflow.com/notebooks/examples/dog-3.jpeg",
        "jacket",
        0.5,
    ],
    [
        "https://media.roboflow.com/notebooks/examples/dog-3.jpeg",
        "coffee",
        0.6,
    ],
]

# Masks whose area is not larger than this fraction of the image area are
# discarded by inference().
MIN_AREA_THRESHOLD = 0.01

# Use the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
# Initialize the SAM mask-generation pipeline. A failure is reported but not
# fatal, so the module (and the UI) still loads.
try:
    SAM_GENERATOR = pipeline(
        task="mask-generation",
        model="facebook/sam-vit-large",
        device=DEVICE,
    )
except Exception as e:
    print(f"Error initializing SAM generator: {e}")
    # Bind the name anyway. Without this, every later reference raises
    # NameError, which hides the real initialization error above.
    SAM_GENERATOR = None
# Mask annotators (both use ColorLookup.INDEX):
# - SEMITRANSPARENT_MASK_ANNOTATOR overlays red masks at the annotator's
#   default opacity;
# - SOLID_MASK_ANNOTATOR paints white masks fully opaque (opacity=1).
SEMITRANSPARENT_MASK_ANNOTATOR = sv.MaskAnnotator(
    color_lookup=sv.ColorLookup.INDEX,
    color=sv.Color.red(),
)
SOLID_MASK_ANNOTATOR = sv.MaskAnnotator(
    opacity=1,
    color_lookup=sv.ColorLookup.INDEX,
    color=sv.Color.white(),
)
# Functions
def run_sam(image_rgb_pil: Image.Image) -> sv.Detections:
    """Generate masks for every segment SAM finds in an image.

    Args:
        image_rgb_pil: Image handed as-is to the ``mask-generation`` pipeline.

    Returns:
        ``sv.Detections`` with one boolean mask per generated segment and the
        bounding boxes derived from those masks. ``sv.Detections.empty()`` is
        returned when the pipeline is unavailable or fails, or when it
        produces no masks.
    """
    try:
        outputs = SAM_GENERATOR(image_rgb_pil, points_per_batch=32)
        # Masks must be boolean: supervision indexes the image with each mask,
        # and an integer (uint8) array used as an index would select rows 0/1
        # instead of the masked pixels.
        masks = np.asarray(outputs["masks"], dtype=bool)
    except Exception as e:
        # Also covers a generator that failed to initialize (name missing or
        # bound to None).
        print(f"Error running SAM model: {e}")
        return sv.Detections.empty()
    # No masks yields a (0,)-shaped array, which mask_to_xyxy cannot handle.
    if masks.ndim != 3 or masks.shape[0] == 0:
        return sv.Detections.empty()
    return sv.Detections(xyxy=sv.mask_to_xyxy(masks=masks), mask=masks)
def reverse_mask_image(image: np.ndarray, mask: np.ndarray, gray_value=128):
    """Keep the pixels selected by ``mask`` and paint every other pixel gray.

    Args:
        image: Array whose last axis holds three color channels
            (presumably ``(H, W, 3)`` — confirm with callers).
        mask: Array over the image's spatial axes; truthy entries keep the
            original pixel.
        gray_value: Intensity written to all three channels of unselected
            pixels, stored as ``np.uint8``.

    Returns:
        A new array produced by ``np.where``; ``image`` is not modified.
    """
    fill = np.array((gray_value,) * 3, dtype=np.uint8)
    keep = mask[..., np.newaxis]  # add a channel axis so the mask broadcasts
    return np.where(keep, image, fill)
def inference(image_rgb_pil: Image.Image) -> List[Image.Image]:
    """Segment an image with SAM and render two mask visualizations.

    Args:
        image_rgb_pil: Image from the Gradio ``gr.Image`` input
            (configured with ``type='pil'``), or None if none was provided.

    Returns:
        Two annotated images for the gallery: the input overlaid with the
        semi-transparent red masks, and a black canvas with the masks painted
        solid white. Masks whose area is not larger than
        ``MIN_AREA_THRESHOLD`` of the image area are dropped from both.

    Raises:
        gr.Error: If no image was provided.
    """
    if image_rgb_pil is None:
        # gr.Image delivers None when the button is clicked with no image;
        # report it clearly instead of failing with AttributeError on `.size`.
        raise gr.Error("Please upload an image first.")
    width, height = image_rgb_pil.size
    image_area = width * height
    detections = run_sam(image_rgb_pil)
    # Keep only masks covering more than MIN_AREA_THRESHOLD of the image.
    detections = detections[detections.area / image_area > MIN_AREA_THRESHOLD]
    blank_image = Image.new("RGB", (width, height), "black")
    return [
        SEMITRANSPARENT_MASK_ANNOTATOR.annotate(image_rgb_pil, detections),
        SOLID_MASK_ANNOTATOR.annotate(blank_image, detections),
    ]
# ---------------------------------------------------------------------------
# Gradio UI construction
# ---------------------------------------------------------------------------
with gr.Blocks() as demo:
    gr.Markdown(MARKDOWN)
    with gr.Row():
        # Input column (image + run button) next to the result gallery.
        with gr.Column():
            input_image = gr.Image(image_mode='RGB', type='pil', height=500)
            submit_button = gr.Button("Pruébalo!!!")
        gallery = gr.Gallery(label="Result", object_fit="scale-down", preview=True)
    with gr.Row():
        # Only one input component is wired, so pass just the image column of
        # EXAMPLES instead of relying on Gradio to ignore the extra columns.
        gr.Examples(
            examples=[[row[0]] for row in EXAMPLES],
            fn=inference,
            inputs=[input_image],
            outputs=[gallery],
            cache_examples=False,
            run_on_click=True,
        )
    submit_button.click(
        inference,
        inputs=[input_image],
        outputs=[gallery],
    )

# Start the server only when run as a script, so importing this module does
# not launch the app as a side effect.
if __name__ == "__main__":
    demo.launch(debug=False, show_error=True)