import gradio as gr
import torch
import cv2
import numpy as np
from transformers import SamModel, SamProcessor, BlipProcessor, BlipForConditionalGeneration
from PIL import Image
from scipy.ndimage import label

# Set up device
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load SAM model and processor
sam_model = SamModel.from_pretrained("facebook/sam-vit-base").to(device)
sam_processor = SamProcessor.from_pretrained("facebook/sam-vit-base")

# Load BLIP model and processor for image-to-text
blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base").to(device)


def process_mask(mask, target_size):
    # Collapse any extra dimensions down to a single 2D mask
    if mask.ndim > 2:
        mask = mask.squeeze()
    if mask.ndim > 2:
        mask = mask[0]
    # Binarize, resize to the target (width, height), and return a boolean array
    mask = (mask > 0.5).astype(np.uint8) * 255
    mask_image = Image.fromarray(mask)
    mask_image = mask_image.resize(target_size, Image.NEAREST)
    return np.array(mask_image) > 0


def is_cat_like(mask, image_area):
    # Heuristic: decide whether a binary mask plausibly outlines a cat
    labeled, num_features = label(mask)
    if num_features == 0:
        return False
    largest_component = labeled == (np.bincount(labeled.flatten())[1:].argmax() + 1)
    area = largest_component.sum()

    # Check if the area is reasonable for a cat (between 5% and 30% of the image)
    if not (0.05 * image_area < area < 0.3 * image_area):
        return False

    # Check if the shape is roughly elliptical, using the aspect ratio of the
    # component's bounding box (not the shape of the full mask array)
    ys, xs = np.nonzero(largest_component)
    height = ys.max() - ys.min() + 1
    width = xs.max() - xs.min() + 1
    aspect_ratio = max(height, width) / min(height, width)
    return 1.5 < aspect_ratio < 3  # Most cats have an aspect ratio in this range


def segment_image(input_image, object_name):
    try:
        if input_image is None:
            return None, "Please upload an image before submitting."

        input_image = Image.fromarray(input_image).convert("RGB")
        original_size = input_image.size
        if not original_size or 0 in original_size:
            return None, "Invalid image size. Please upload a different image."

        # Generate a detailed image caption with BLIP
        blip_inputs = blip_processor(input_image, return_tensors="pt").to(device)
        caption = blip_model.generate(**blip_inputs, max_length=50)
        caption_text = blip_processor.decode(caption[0], skip_special_tokens=True)

        # Process the image with SAM
        sam_inputs = sam_processor(input_image, return_tensors="pt").to(device)

        # Generate masks
        with torch.no_grad():
            sam_outputs = sam_model(**sam_inputs)

        # Post-process masks back to the original image resolution
        masks = sam_processor.image_processor.post_process_masks(
            sam_outputs.pred_masks.cpu(),
            sam_inputs["original_sizes"].cpu(),
            sam_inputs["reshaped_input_sizes"].cpu(),
        )

        # Find the mask that best matches the specified object.
        # NOTE: the matching heuristic below is cat-specific; object_name is
        # currently only used in the status messages.
        best_mask = None
        best_score = -1
        image_area = original_size[0] * original_size[1]
        cat_related_words = ['cat', 'kitten', 'feline', 'tabby', 'kitty']
        caption_contains_cat = any(word in caption_text.lower() for word in cat_related_words)

        # Flatten candidates to individual 2D masks (post-processed masks can
        # carry extra prompt/mask dimensions)
        candidate_masks = masks[0].reshape(-1, *masks[0].shape[-2:])
        for mask in candidate_masks:
            mask_binary = mask.numpy() > 0.5
            if caption_contains_cat and is_cat_like(mask_binary, image_area):
                mask_area = mask_binary.sum()
                if mask_area > best_score:
                    best_mask = mask_binary
                    best_score = mask_area

        if best_mask is None:
            return input_image, f"Could not find a suitable '{object_name}' in the image."
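
        # process_mask resizes and binarizes the winning mask so it aligns
        # pixel-for-pixel with the input image in the overlay step below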
        combined_mask = process_mask(best_mask, original_size)

        # Overlay the mask on the original image
        result_image = np.array(input_image)
        mask_rgb = np.zeros_like(result_image)
        mask_rgb[combined_mask] = [255, 0, 0]  # Red color for the mask
        result_image = cv2.addWeighted(result_image, 1, mask_rgb, 0.5, 0)

        return result_image, f"Segmented '{object_name}' in the image."
    except Exception as e:
        return None, f"An error occurred: {str(e)}"


# Create Gradio interface
iface = gr.Interface(
    fn=segment_image,
    inputs=[
        gr.Image(type="numpy", label="Upload an image"),
        gr.Textbox(label="Specify object to segment (e.g., dog, cat, grass)"),
    ],
    outputs=[
        gr.Image(type="numpy", label="Segmented Image"),
        gr.Textbox(label="Status"),
    ],
    title="Segment Anything Model (SAM) with Object Specification",
    description="Upload an image and specify an object to segment.",
)

# Launch the interface
iface.launch()
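
# A minimal sketch for exercising segment_image without the web UI; "cat.jpg" is a
# hypothetical file name, substitute any local image. Since iface.launch() blocks,
# run this in a separate session or before launching:
#
#   test_img = np.array(Image.open("cat.jpg").convert("RGB"))
#   result, status = segment_image(test_img, "cat")
#   print(status)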