File size: 2,980 Bytes
f6b477c
 
 
 
 
 
605391e
f6b477c
605391e
f6b477c
605391e
f6b477c
 
 
 
 
 
 
 
 
605391e
 
 
 
 
 
 
 
 
 
 
f6b477c
605391e
f6b477c
605391e
 
f6b477c
 
 
605391e
 
 
f6b477c
 
605391e
 
 
 
 
 
 
 
 
f6b477c
 
605391e
f6b477c
 
605391e
f6b477c
 
 
605391e
 
 
f6b477c
 
605391e
 
f6b477c
 
 
 
 
 
 
605391e
f6b477c
605391e
 
f6b477c
 
605391e
 
 
 
 
 
f6b477c
 
 
605391e
 
f6b477c
 
605391e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
from typing import List
import gradio as gr
import numpy as np
import supervision as sv
import torch
from PIL import Image
from transformers import pipeline

# Global Variables
# NOTE(review): the model loaded below is facebook/sam-vit-large (Segment
# Anything Model); confirm the "Softly Activated Masks" expansion is intended.
MARKDOWN = """
# SAM - Softly Activated Masks
"""
# Example inputs for gr.Examples. Each row holds exactly one value because the
# Examples component is wired to a single image input; extra columns would be
# ignored by Gradio and only produce duplicate entries.
EXAMPLES = [
    ["https://media.roboflow.com/notebooks/examples/dog.jpeg"],
    ["https://media.roboflow.com/notebooks/examples/dog-3.jpeg"],
]

# Masks whose area is not above this fraction of the image area are discarded.
MIN_AREA_THRESHOLD = 0.01
# Run the model on the GPU when CUDA is available, otherwise on the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Initialize the SAM mask-generation pipeline once, at import time.
# If loading fails, the error is printed and SAM_GENERATOR is bound to None,
# so later code sees an explicit "not available" value instead of hitting a
# NameError on an undefined global.
SAM_GENERATOR = None
try:
    SAM_GENERATOR = pipeline(
        task="mask-generation",
        model="facebook/sam-vit-large",
        device=DEVICE,
    )
except Exception as e:
    print(f"Error initializing SAM generator: {e}")

# Mask Annotators
# Settings shared by both annotators: masks are coloured by detection index.
_MASK_ANNOTATOR_DEFAULTS = {"color_lookup": sv.ColorLookup.INDEX}

# NOTE(review): newer supervision releases deprecate Color.red()/Color.white()
# in favour of Color.RED/Color.WHITE; confirm against the installed version.

# Red mask overlay (opacity left at the supervision default).
SEMITRANSPARENT_MASK_ANNOTATOR = sv.MaskAnnotator(
    color=sv.Color.red(), **_MASK_ANNOTATOR_DEFAULTS
)

# White masks drawn fully opaque (opacity=1), i.e. as solid regions.
SOLID_MASK_ANNOTATOR = sv.MaskAnnotator(
    color=sv.Color.white(), opacity=1, **_MASK_ANNOTATOR_DEFAULTS
)

# Functions
def run_sam(image_rgb_pil: Image.Image) -> sv.Detections:
    """Run SAM automatic mask generation on an image.

    Args:
        image_rgb_pil: Input image passed straight to the SAM pipeline.

    Returns:
        Detections with one boolean mask per SAM proposal and bounding boxes
        derived from those masks. If the generator is unavailable, raises, or
        produces no masks, an empty Detections object is returned instead, so
        callers can filter/annotate it without special-casing.
    """
    try:
        if SAM_GENERATOR is None:
            raise RuntimeError("SAM generator is not initialized")
        # points_per_batch: number of prompt points the pipeline runs through
        # the model at once (trades speed for memory).
        outputs = SAM_GENERATOR(image_rgb_pil, points_per_batch=32)
        # Masks must be boolean: MaskAnnotator selects pixels with
        # `image[mask]`, and an integer mask would do fancy indexing instead.
        masks = np.asarray(outputs["masks"], dtype=bool)
    except Exception as e:
        print(f"Error running SAM model: {e}")
        return sv.Detections.empty()

    # No proposals: np.asarray([]) has shape (0,), which is not a valid
    # (N, H, W) mask stack for sv.Detections.
    if masks.ndim != 3 or masks.shape[0] == 0:
        return sv.Detections.empty()

    return sv.Detections(xyxy=sv.mask_to_xyxy(masks=masks), mask=masks)

def reverse_mask_image(image: np.ndarray, mask: np.ndarray, gray_value=128):
    """Keep the pixels selected by ``mask`` and paint every other pixel grey.

    Args:
        image: Colour image; assumed to have a trailing 3-channel axis
            (e.g. H x W x 3) — TODO confirm with callers.
        mask: Per-pixel selector with the image's leading shape; truthy
            entries keep the original pixel.
        gray_value: Value written to all three channels of unselected pixels.

    Returns:
        A new array combining ``image`` (where ``mask`` is truthy) with the
        uint8 colour ``(gray_value, gray_value, gray_value)`` elsewhere.
    """
    fill_rgb = np.array([gray_value] * 3, dtype=np.uint8)
    # Add a trailing axis so the mask broadcasts across the colour channels.
    keep = mask[..., np.newaxis]
    return np.where(keep, image, fill_rgb)

def inference(image_rgb_pil: Image.Image) -> List[Image.Image]:
    """Segment an image with SAM and render two visualisations of the masks.

    Args:
        image_rgb_pil: Input picture as a PIL image (the UI requests RGB mode).
            Gradio passes None when the button is clicked with no image.

    Returns:
        A two-element list: the input image annotated by
        SEMITRANSPARENT_MASK_ANNOTATOR, and a black canvas of the same size
        annotated by SOLID_MASK_ANNOTATOR.

    Raises:
        gr.Error: If no image was provided, so the UI shows a readable message
            instead of an AttributeError on ``None.size``.
    """
    if image_rgb_pil is None:
        raise gr.Error("Please upload an image first.")

    width, height = image_rgb_pil.size
    image_area = width * height

    detections = run_sam(image_rgb_pil)
    # Keep only masks larger than MIN_AREA_THRESHOLD as a fraction of the
    # whole image (detection area is normalised by width * height).
    detections = detections[detections.area / image_area > MIN_AREA_THRESHOLD]

    blank_image = Image.new("RGB", (width, height), "black")
    return [
        SEMITRANSPARENT_MASK_ANNOTATOR.annotate(image_rgb_pil, detections),
        SOLID_MASK_ANNOTATOR.annotate(blank_image, detections),
    ]
# ---------------------------------------------------------------------------
# Gradio UI construction
# ---------------------------------------------------------------------------
with gr.Blocks() as demo:
    gr.Markdown(MARKDOWN)
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(image_mode='RGB', type='pil', height=500)
            submit_button = gr.Button("Pruébalo!!!")
        gallery = gr.Gallery(label="Result", object_fit="scale-down", preview=True)

    with gr.Row():
        # Clicking an example runs `inference` directly (no cached results).
        gr.Examples(
            examples=EXAMPLES,
            fn=inference,
            inputs=[input_image],
            outputs=[gallery],
            cache_examples=False,
            run_on_click=True,
        )
    submit_button.click(
        inference,
        inputs=[input_image],
        outputs=[gallery],
    )

# Launch only when run as a script, so importing this module (tests, reload
# tooling) builds `demo` without starting a server.
if __name__ == "__main__":
    demo.launch(debug=False, show_error=True)