Boni98 commited on
Commit
bdba3b8
·
1 Parent(s): af95858

Added application file

Browse files
Files changed (1) hide show
  1. app.py +118 -29
app.py CHANGED
@@ -4,42 +4,131 @@ from PIL import Image
4
  import tempfile
5
  import os
6
 
7
- def generate_binary_mask(image_data):
8
- if image_data is None or "layers" not in image_data or not image_data["layers"]:
9
- raise gr.Error("Please draw a mask before generating!")
 
10
 
11
- mask = image_data["layers"][0]
12
- mask_array = np.array(mask)
 
13
 
14
- if np.all(mask_array < 10):
15
- raise gr.Error("The mask is empty! Please draw something.")
 
 
 
 
 
 
 
 
16
 
17
- # Binary mask logic
18
- is_black = np.all(mask_array < 10, axis=2)
19
- binary_mask = Image.fromarray(((~is_black) * 255).astype(np.uint8))
20
 
21
- # Save to temporary file
22
- temp_dir = tempfile.mkdtemp()
23
- output_path = os.path.join(temp_dir, "binary_mask.png")
24
- binary_mask.save(output_path)
 
 
 
 
 
 
25
 
26
- return binary_mask, output_path
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
 
28
- with gr.Blocks() as app:
29
  with gr.Row():
30
- with gr.Column():
31
- image_input = gr.ImageMask(
32
- label="Upload or Paste Image, then draw mask",
33
- type="pil",
34
- height=None,
35
- width=None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
  )
37
- generate_btn = gr.Button("Generate Mask")
38
-
39
- with gr.Column():
40
- mask_preview = gr.Image(label="Mask Preview", type="pil")
41
- output_file = gr.File(label="Download Mask")
 
 
 
 
 
 
 
 
 
42
 
43
- generate_btn.click(fn=generate_binary_mask, inputs=image_input, outputs=[mask_preview, output_file])
 
 
 
 
 
 
 
 
 
44
 
45
- app.launch()
 
4
  import tempfile
5
  import os
6
 
7
+ # Function to create the binary mask from the ImageEditor's output
8
+ def create_binary_mask(im_dict):
9
+ """
10
+ Generates a binary mask from the drawing layer of the gr.ImageEditor output.
11
 
12
+ Args:
13
+ im_dict (dict): The dictionary output from gr.ImageEditor, containing
14
+ 'background', 'layers', and 'composite'.
15
 
16
+ Returns:
17
+ tuple: A tuple containing:
18
+ - np.ndarray: The binary mask image (H, W) as a NumPy array (0 or 255).
19
+ - str or None: The filepath to the saved PNG mask for download, or None if no mask generated.
20
+ """
21
+ if im_dict is None or im_dict["background"] is None:
22
+ print("No background image found.")
23
+ # Return a small blank placeholder and None for the file path
24
+ blank_preview = np.zeros((100, 100), dtype=np.uint8)
25
+ return blank_preview, None
26
 
27
+ background_img = im_dict["background"]
28
+ h, w, _ = background_img.shape # Get original dimensions (Height, Width, Channels)
29
+ print(f"Original image dimensions: H={h}, W={w}")
30
 
31
+ # Check if any drawing layer exists and is not None
32
+ if not im_dict["layers"] or im_dict["layers"][0] is None:
33
+ print("No drawing layer found. Generating blank mask.")
34
+ # Nothing drawn yet, return a black mask of the original size
35
+ mask = np.zeros((h, w), dtype=np.uint8)
36
+ filepath = None # No file to download as nothing was drawn
37
+ else:
38
+ # Use the first layer (index 0) which usually contains the drawing
39
+ layer = im_dict["layers"][0]
40
+ print(f"Drawing layer dimensions: H={layer.shape[0]}, W={layer.shape[1]}")
41
 
42
+ # Ensure layer dimensions match background (Gradio ImageEditor usually handles this)
43
+ if layer.shape[0] != h or layer.shape[1] != w:
44
+ print(f"Warning: Layer size ({layer.shape[0]}x{layer.shape[1]}) doesn't match background ({h}x{w}). This shouldn't happen.")
45
+ # Handle potential mismatch if necessary, though unlikely with default editor behavior
46
+ # For now, proceed assuming they match or the layer is the correct reference
47
+
48
+ # Layer is RGBA, extract the Alpha channel (index 3)
49
+ alpha_channel = layer[:, :, 3]
50
+
51
+ # Create binary mask: white (255) where alpha > 0 (drawn), black (0) otherwise
52
+ mask = np.where(alpha_channel > 0, 255, 0).astype(np.uint8)
53
+ print(f"Generated binary mask dimensions: H={mask.shape[0]}, W={mask.shape[1]}")
54
+
55
+ # Save the mask to a temporary PNG file for download
56
+ try:
57
+ # Create a temporary file path
58
+ with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmpfile:
59
+ filepath = tmpfile.name
60
+
61
+ # Save the NumPy array as a PNG image using PIL
62
+ pil_image = Image.fromarray(mask)
63
+ pil_image.save(filepath, format="PNG")
64
+ print(f"Mask saved temporarily to: {filepath}")
65
+
66
+ except Exception as e:
67
+ print(f"Error saving mask to temporary file: {e}")
68
+ filepath = None # Indicate failure to save
69
+ # Return a blank mask in case of saving error
70
+ mask = np.zeros((h, w), dtype=np.uint8)
71
+
72
+ # Return the mask NumPy array for preview and the filepath for download
73
+ # The DownloadButton component will become active/functional if filepath is not None
74
+ return mask, filepath
75
+
76
+ # --- Gradio App Layout ---
77
+ with gr.Blocks() as demo:
78
+ gr.Markdown("## Binary Mask Generator")
79
+ gr.Markdown(
80
+ "Upload or paste an image. Use the brush tool (select it!) to draw the area "
81
+ "you want to mask. Click 'Generate Mask' to see the result and download it."
82
+ )
83
 
 
84
  with gr.Row():
85
+ # --- Left Column ---
86
+ with gr.Column(scale=1): # Adjust scale as needed
87
+ image_editor = gr.ImageEditor(
88
+ label="Draw on Image",
89
+ # type="numpy" is essential for processing layers
90
+ type="numpy",
91
+ # DON'T set crop_size, height, or width to keep original dimensions
92
+ # sources allow upload, paste, webcam etc.
93
+ sources=["upload"],
94
+ # Set a default brush for clarity (optional, but helpful)
95
+ brush=gr.Brush(colors=["#FF0000"], color_mode="fixed"), # Red fixed brush
96
+ interactive=True,
97
+ )
98
+ generate_button = gr.Button("Generate Mask", variant="primary")
99
+
100
+ # --- Right Column ---
101
+ with gr.Column(scale=1): # Adjust scale as needed
102
+ mask_preview = gr.Image(
103
+ label="Binary Mask Preview",
104
+ # Use numpy for consistency, PIL would also work
105
+ type="numpy",
106
+ interactive=False, # Preview is not interactive
107
  )
108
+ # Download button - its value (the file path) is set by the function's output
109
+ download_button = gr.DownloadButton(
110
+ label="Download Mask (PNG)",
111
+ interactive=True, # Button starts interactive
112
+ )
113
+
114
+ # --- Event Handling ---
115
+ generate_button.click(
116
+ fn=create_binary_mask,
117
+ inputs=[image_editor],
118
+ # Output 1 goes to mask_preview (image data)
119
+ # Output 2 goes to download_button (file path for the 'value' argument)
120
+ outputs=[mask_preview, download_button]
121
+ )
122
 
123
+ # --- Launch the App ---
124
+ if __name__ == "__main__":
125
+ # Cleaning up old temp files on startup (optional but good practice)
126
+ temp_dir = tempfile.gettempdir()
127
+ for item in os.listdir(temp_dir):
128
+ if item.endswith(".png") and item.startswith("tmp"): # Be specific to avoid deleting wrong files
129
+ try:
130
+ os.remove(os.path.join(temp_dir, item))
131
+ except Exception:
132
+ pass # Ignore if file is locked etc.
133
 
134
+ demo.launch()