SAMSEGMENT / app.py
fireedman's picture
Update app.py
605391e verified
raw
history blame
2.98 kB
from typing import List
import gradio as gr
import numpy as np
import supervision as sv
import torch
from PIL import Image
from transformers import pipeline
# Global Variables

# Markdown source rendered at the top of the demo page.
MARKDOWN = """
# SAM - Softly Activated Masks
"""

# Rows offered through the examples widget. Each row holds an image URL,
# a text label and a float.
# NOTE(review): the examples widget in this file is wired to a single image
# input, so the label and float columns have no matching component here --
# confirm whether they are leftovers that can be dropped.
EXAMPLES = [
    ["https://media.roboflow.com/notebooks/examples/dog.jpeg", "dog", 0.5],
    ["https://media.roboflow.com/notebooks/examples/dog.jpeg", "building", 0.5],
    ["https://media.roboflow.com/notebooks/examples/dog-3.jpeg", "jacket", 0.5],
    ["https://media.roboflow.com/notebooks/examples/dog-3.jpeg", "coffee", 0.6],
]

# Minimum mask area, as a fraction of the whole image area, for a detection
# to be kept by `inference`.
MIN_AREA_THRESHOLD = 0.01

# Device string passed to the mask-generation pipeline: GPU when available.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Initialize SAM Generator with exception handling.
# The app keeps starting even if the model cannot be loaded; in that case
# SAM_GENERATOR is bound to None so the name always exists and later code
# fails with a clear "not callable" error instead of a NameError.
try:
    SAM_GENERATOR = pipeline(
        task="mask-generation",
        model="facebook/sam-vit-large",
        device=DEVICE
    )
except Exception as e:
    print(f"Error initializing SAM generator: {e}")
    SAM_GENERATOR = None
# Mask Annotators

# Red overlay drawn on top of the input image. No opacity is passed, so the
# annotator's default opacity applies.
SEMITRANSPARENT_MASK_ANNOTATOR = sv.MaskAnnotator(
    color=sv.Color.red(),
    color_lookup=sv.ColorLookup.INDEX
)

# White overlay with opacity=1, drawn on a black canvas by `inference`.
SOLID_MASK_ANNOTATOR = sv.MaskAnnotator(
    color=sv.Color.white(),
    color_lookup=sv.ColorLookup.INDEX,
    opacity=1
)
# Functions
def run_sam(image_rgb_pil: Image.Image) -> sv.Detections:
    """Run the SAM mask-generation pipeline and wrap the masks as detections.

    Args:
        image_rgb_pil: Image handed to the mask-generation pipeline.

    Returns:
        Detections holding one boolean mask per generated segment together
        with the bounding box derived from each mask. If the pipeline or the
        conversion fails, the error is printed and an empty ``sv.Detections``
        is returned so the caller can still render a result.
    """
    try:
        outputs = SAM_GENERATOR(image_rgb_pil, points_per_batch=32)
        # Masks must be boolean: supervision documents `Detections.mask` as a
        # bool array, and an integer array used as an index would be treated
        # as row indices rather than as a per-pixel selector.
        mask = np.array(outputs['masks'], dtype=bool)
        return sv.Detections(xyxy=sv.mask_to_xyxy(masks=mask), mask=mask)
    except Exception as e:
        # Deliberate best-effort boundary: report and fall back to "nothing
        # detected" instead of crashing the UI callback.
        print(f"Error running SAM model: {e}")
        # `Detections.empty()` builds a valid zero-length container; plain
        # lists for `xyxy`/`mask` are not valid array inputs.
        return sv.Detections.empty()
def reverse_mask_image(image: np.ndarray, mask: np.ndarray, gray_value=128):
    """Keep the pixels selected by ``mask`` and paint the rest flat gray.

    Args:
        image: Image array whose last axis holds the three colour channels.
        mask: Per-pixel selector; truthy entries keep the image pixel.
        gray_value: Channel value used for every pixel outside the mask.

    Returns:
        A new array with the image's pixels where the mask is truthy and
        ``(gray_value, gray_value, gray_value)`` everywhere else.
    """
    fill_pixel = np.array([gray_value] * 3, dtype=np.uint8)
    # Add a trailing channel axis so the mask broadcasts across the channels.
    keep_pixel = mask[..., np.newaxis]
    return np.where(keep_pixel, image, fill_pixel)
def inference(image_rgb_pil: Image.Image) -> List[Image.Image]:
    """Segment an image and render the surviving masks in two ways.

    Args:
        image_rgb_pil: Image supplied by the UI. It is ``None`` when the user
            triggers the callback without having loaded an image.

    Returns:
        Two annotated scenes: the input image with the masks overlaid by
        ``SEMITRANSPARENT_MASK_ANNOTATOR``, and a black canvas of the same
        size with the masks drawn by ``SOLID_MASK_ANNOTATOR``.

    Raises:
        gr.Error: If no image was provided.
    """
    # Fail with a readable message instead of an AttributeError on `.size`.
    if image_rgb_pil is None:
        raise gr.Error("Please provide an image before running segmentation.")

    width, height = image_rgb_pil.size
    area = width * height

    detections = run_sam(image_rgb_pil)
    # Drop detections whose area is at most MIN_AREA_THRESHOLD of the image.
    detections = detections[detections.area / area > MIN_AREA_THRESHOLD]

    blank_image = Image.new("RGB", (width, height), "black")
    return [
        SEMITRANSPARENT_MASK_ANNOTATOR.annotate(image_rgb_pil, detections),
        SOLID_MASK_ANNOTATOR.annotate(blank_image, detections)
    ]
#************
#GRADIO CONSTRUCTION
# Build the demo page: a header, an image input with a submit button next to
# a result gallery, and a row of clickable examples.
# NOTE(review): the nesting below (gallery beside the input column, examples
# in their own row) is reconstructed -- confirm against the intended layout.
with gr.Blocks() as demo:
    gr.Markdown(MARKDOWN)
    with gr.Row():
        with gr.Column():
            # type='pil' makes the callback receive a PIL image.
            input_image = gr.Image(image_mode='RGB', type='pil', height=500)
            submit_button = gr.Button("Pruébalo!!!")
        gallery = gr.Gallery(label="Result", object_fit="scale-down", preview=True)
    with gr.Row():
        # Examples are not cached; clicking one runs `inference` directly.
        gr.Examples(
            examples=EXAMPLES,
            fn=inference,
            inputs=[input_image],
            outputs=[gallery],
            cache_examples=False,
            run_on_click=True
        )
    # Event wiring is declared inside the Blocks context.
    submit_button.click(
        inference,
        inputs=[input_image],
        outputs=gallery
    )

# Start the web server; show_error=True surfaces callback errors in the UI.
demo.launch(debug=False, show_error=True)