"""Gradio app that turns brush strokes drawn on an image into a binary mask.

The user uploads an image, paints over it in a ``gr.ImageEditor`` and clicks
"Generate Mask". The painted pixels become white (255) in a single-channel
mask, everything else is black (0). The mask is shown as a preview and offered
as a PNG download.
"""

import os
import tempfile

import gradio as gr
import numpy as np
from PIL import Image

# Shape (height, width) of the blank preview returned when no image is loaded.
PLACEHOLDER_SHAPE = (768, 1024)

# Mask files get a distinctive prefix so that the startup cleanup can remove
# exactly the files this app created and nothing else in the temp directory.
MASK_FILE_PREFIX = "binary_mask_"
MASK_FILE_SUFFIX = ".png"


def _save_mask_png(mask):
    """Write *mask* to a new temporary PNG file.

    Args:
        mask: 2-D ``uint8`` NumPy array to store.

    Returns:
        The path of the written file, or ``None`` if saving failed. On failure
        the partially created temporary file is removed again.
    """
    filepath = None
    try:
        # delete=False: the file must outlive this function so that the
        # download button can serve it later.
        with tempfile.NamedTemporaryFile(
            prefix=MASK_FILE_PREFIX, suffix=MASK_FILE_SUFFIX, delete=False
        ) as tmpfile:
            filepath = tmpfile.name
        # Save after the handle is closed so the path can be reopened on
        # platforms that forbid opening an already-open temporary file.
        Image.fromarray(mask).save(filepath, format="PNG")
    except Exception as e:  # UI callback boundary: report and degrade gracefully
        print(f"Error saving mask to temporary file: {e}")
        if filepath is not None:
            try:
                os.remove(filepath)  # do not leave an empty/corrupt file behind
            except OSError:
                pass  # best effort; the file may already be gone or locked
        return None

    print(f"Mask saved temporarily to: {filepath}")
    return filepath


def create_binary_mask(im_dict):
    """Build a binary mask from the value of a ``gr.ImageEditor``.

    Args:
        im_dict: The editor value, a dict with the keys ``"background"`` (the
            uploaded image as a NumPy array) and ``"layers"`` (a list of
            drawing layers as NumPy arrays), or ``None``.

    Returns:
        A ``(mask, filepath)`` tuple. ``mask`` is a 2-D ``uint8`` array that is
        255 where the first drawing layer has a non-zero alpha value and 0
        elsewhere. ``filepath`` is the path of a temporary PNG containing the
        mask, or ``None`` when there is nothing to download (no image, no
        drawing layer, a layer without alpha channel, or a failed save).
    """
    if im_dict is None or im_dict.get("background") is None:
        print("No background image found.")
        # Small blank placeholder for the preview; nothing to download.
        return np.zeros(PLACEHOLDER_SHAPE, dtype=np.uint8), None

    background_img = im_dict["background"]
    # Only the first two axes are needed; slicing also accepts a 2-D
    # (grayscale) background, which a 3-way unpack would reject.
    h, w = background_img.shape[:2]
    print(f"Original image dimensions: H={h}, W={w}")

    layers = im_dict.get("layers")
    if not layers or layers[0] is None:
        print("No drawing layer found. Generating blank mask.")
        # Nothing drawn yet: black mask of the original size, no download.
        return np.zeros((h, w), dtype=np.uint8), None

    # The first layer (index 0) is the one used as the drawing.
    layer = layers[0]
    print(f"Drawing layer dimensions: H={layer.shape[0]}, W={layer.shape[1]}")

    if layer.shape[0] != h or layer.shape[1] != w:
        # NOTE(review): the mismatch is only reported; the mask keeps the
        # layer's size. How the two sizes map onto each other is not known
        # here, so no resizing is attempted.
        print(
            f"Warning: Layer size ({layer.shape[0]}x{layer.shape[1]}) doesn't "
            f"match background ({h}x{w}). This shouldn't happen."
        )

    if layer.ndim != 3 or layer.shape[2] < 4:
        # Without an alpha channel there is no way to tell drawn pixels apart.
        print("Drawing layer has no alpha channel. Generating blank mask.")
        return np.zeros((h, w), dtype=np.uint8), None

    # The alpha channel (index 3) is non-zero exactly where something was drawn.
    alpha_channel = layer[:, :, 3]
    mask = np.where(alpha_channel > 0, 255, 0).astype(np.uint8)
    print(f"Generated binary mask dimensions: H={mask.shape[0]}, W={mask.shape[1]}")

    filepath = _save_mask_png(mask)
    if filepath is None:
        # Saving failed: show a blank mask so preview and download stay in sync.
        mask = np.zeros((h, w), dtype=np.uint8)

    # The first value feeds the preview image, the second the download button.
    return mask, filepath


def _remove_stale_mask_files():
    """Delete mask PNGs left in the temp directory by earlier runs.

    Only files carrying this app's own prefix and suffix are removed, so
    temporary files belonging to other programs are never touched.
    """
    temp_dir = tempfile.gettempdir()
    for item in os.listdir(temp_dir):
        if item.startswith(MASK_FILE_PREFIX) and item.endswith(MASK_FILE_SUFFIX):
            try:
                os.remove(os.path.join(temp_dir, item))
            except OSError:
                pass  # best effort: the file may be locked or already removed


# --- Gradio App Layout ---
with gr.Blocks() as demo:
    gr.Markdown("## Binary Mask Generator")
    gr.Markdown(
        "Upload or paste an image. Use the brush tool (select it!) to draw the area "
        "you want to mask. Click 'Generate Mask' to see the result and download it."
    )

    with gr.Row():
        # --- Left Column ---
        with gr.Column(scale=1):
            image_editor = gr.ImageEditor(
                label="Draw on Image",
                # type="numpy" is essential: the callback reads array shapes
                # and the alpha channel of the layers.
                type="numpy",
                # Only file upload is offered as an image source.
                sources=["upload"],
                # A single fixed red brush keeps the drawing unambiguous.
                brush=gr.Brush(colors=["#FF0000"], color_mode="fixed"),
                interactive=True,
                canvas_size=(768, 1024),
            )
            generate_button = gr.Button("Generate Mask", variant="primary")

        # --- Right Column ---
        with gr.Column(scale=1):
            mask_preview = gr.Image(
                label="Binary Mask Preview",
                type="numpy",
                interactive=False,  # the preview is display-only
            )
            # The button's value (the file path) is set by the callback output.
            download_button = gr.DownloadButton(
                label="Download Mask (PNG)",
                interactive=True,
            )

    # --- Event Handling ---
    generate_button.click(
        fn=create_binary_mask,
        inputs=[image_editor],
        # Output 1 -> mask_preview (image data)
        # Output 2 -> download_button (file path)
        outputs=[mask_preview, download_button],
    )


# --- Launch the App ---
if __name__ == "__main__":
    # Remove mask files left over from previous runs before starting.
    _remove_stale_mask_files()
    demo.launch(share=True)