import gradio as gr import torch from transformers import AutoProcessor, AutoModel from PIL import Image, ImageDraw, ImageFont import numpy as np import random import os import wget import traceback # --- Configuration & Model Loading --- # Device Selection with fallback DEVICE = "cuda" if torch.cuda.is_available() else "cpu" # Simplified check print(f"Using device: {DEVICE}") # --- CLIP Setup --- CLIP_MODEL_ID = "openai/clip-vit-base-patch32" clip_processor = None clip_model = None def load_clip_model(): global clip_processor, clip_model if clip_processor is None: try: print(f"Loading CLIP processor: {CLIP_MODEL_ID}...") clip_processor = AutoProcessor.from_pretrained(CLIP_MODEL_ID) print("CLIP processor loaded.") except Exception as e: print(f"Error loading CLIP processor: {e}") traceback.print_exc() # Print traceback return False if clip_model is None: try: print(f"Loading CLIP model: {CLIP_MODEL_ID}...") clip_model = AutoModel.from_pretrained(CLIP_MODEL_ID).to(DEVICE) print(f"CLIP model loaded to {DEVICE}.") except Exception as e: print(f"Error loading CLIP model: {e}") traceback.print_exc() # Print traceback return False return True # --- FastSAM Setup --- FASTSAM_CHECKPOINT = "FastSAM-s.pt" FASTSAM_CHECKPOINT_URL = f"https://huggingface.co/CASIA-IVA-Lab/FastSAM-s/resolve/main/{FASTSAM_CHECKPOINT}" fastsam_model = None fastsam_lib_imported = False FastSAM = None # Define placeholders FastSAMPrompt = None # Define placeholders def check_and_import_fastsam(): global fastsam_lib_imported, FastSAM, FastSAMPrompt # Make sure globals are modified if not fastsam_lib_imported: try: from fastsam import FastSAM as FastSAM_lib, FastSAMPrompt as FastSAMPrompt_lib # Use temp names FastSAM = FastSAM_lib # Assign to global FastSAMPrompt = FastSAMPrompt_lib # Assign to global fastsam_lib_imported = True print("fastsam library imported successfully.") except ImportError as e: print(f"Error: 'fastsam' library not found. Please install it: pip install git+https://github.com/CASIA-IVA-Lab/FastSAM.git") print(f"ImportError: {e}") fastsam_lib_imported = False except Exception as e: print(f"Unexpected error during fastsam import: {e}") traceback.print_exc() fastsam_lib_imported = False return fastsam_lib_imported def download_fastsam_weights(retries=3): if not os.path.exists(FASTSAM_CHECKPOINT): print(f"Downloading FastSAM weights: {FASTSAM_CHECKPOINT} from {FASTSAM_CHECKPOINT_URL}...") for attempt in range(retries): try: # Ensure the directory exists if FASTSAM_CHECKPOINT includes a path os.makedirs(os.path.dirname(FASTSAM_CHECKPOINT) or '.', exist_ok=True) wget.download(FASTSAM_CHECKPOINT_URL, FASTSAM_CHECKPOINT) print("FastSAM weights downloaded.") return True # Return True on successful download except Exception as e: print(f"Attempt {attempt + 1}/{retries} failed to download FastSAM weights: {e}") if os.path.exists(FASTSAM_CHECKPOINT): # Cleanup partial download try: os.remove(FASTSAM_CHECKPOINT) except OSError: pass if attempt + 1 == retries: print("Failed to download weights after all attempts.") return False return False # Should not be reached if loop completes, but added for clarity else: print("FastSAM weights already exist.") return True # Weights exist def load_fastsam_model(): global fastsam_model if fastsam_model is None: if not check_and_import_fastsam(): print("Cannot load FastSAM model due to library import failure.") return False if download_fastsam_weights(): # Ensure FastSAM class is available (might fail if import failed earlier but file exists) if FastSAM is None: print("FastSAM class not available, check import status.") return False try: print(f"Loading FastSAM model: {FASTSAM_CHECKPOINT}...") # Instantiate the imported class fastsam_model = FastSAM(FASTSAM_CHECKPOINT) # Move model to device *after* initialization (common practice) # Note: Check FastSAM docs if it needs explicit .to(DEVICE) or handles it internally # fastsam_model.model.to(DEVICE) # Example if needed, adjust based on FastSAM structure print("FastSAM model loaded.") return True except Exception as e: print(f"Error loading FastSAM model weights or initializing: {e}") traceback.print_exc() return False else: print("FastSAM weights not found or download failed.") return False # Model already loaded return True # --- Processing Functions --- def run_clip_zero_shot(image: Image.Image, text_labels: str): # Keep CLIP as is, seems less likely to be the primary issue if not isinstance(image, Image.Image): print(f"CLIP input is not a PIL Image, type: {type(image)}") # Try to convert if it's a numpy array (common from Gradio) if isinstance(image, np.ndarray): try: image = Image.fromarray(image) print("Converted numpy input to PIL Image for CLIP.") except Exception as e: print(f"Failed to convert numpy array to PIL Image: {e}") return "Error: Invalid image input format.", None else: return "Error: Please provide a valid image.", None if clip_model is None or clip_processor is None: if not load_clip_model(): # Return None for the image part on critical error return "Error: CLIP Model could not be loaded.", None if not text_labels: # Return empty dict and original image if no labels return {}, image labels = [label.strip() for label in text_labels.split(',') if label.strip()] if not labels: # Return empty dict and original image if no valid labels return {}, image print(f"Running CLIP zero-shot classification with labels: {labels}") try: # Ensure image is RGB if image.mode != "RGB": print(f"Converting image from {image.mode} to RGB for CLIP.") image = image.convert("RGB") inputs = clip_processor(text=labels, images=image, return_tensors="pt", padding=True).to(DEVICE) with torch.no_grad(): outputs = clip_model(**inputs) # Calculate probabilities logits_per_image = outputs.logits_per_image # B x N_labels probs = logits_per_image.softmax(dim=1) # Softmax over labels # Create confidences dictionary confidences = {labels[i]: float(probs[0, i].item()) for i in range(len(labels))} print(f"CLIP Confidences: {confidences}") # Return confidences and the original (potentially converted) image return confidences, image except Exception as e: print(f"Error during CLIP processing: {e}") traceback.print_exc() # Return error message and None for image return f"Error during CLIP processing: {e}", None def run_fastsam_segmentation(image_pil: Image.Image, conf_threshold: float = 0.4, iou_threshold: float = 0.9): # Add input type check if not isinstance(image_pil, Image.Image): print(f"FastSAM input is not a PIL Image, type: {type(image_pil)}") if isinstance(image_pil, np.ndarray): try: image_pil = Image.fromarray(image_pil) print("Converted numpy input to PIL Image for FastSAM.") except Exception as e: print(f"Failed to convert numpy array to PIL Image: {e}") # Return None for image on error return None, "Error: Invalid image input format." # Return tuple for consistency else: # Return None for image on error return None, "Error: Please provide a valid image." # Return tuple # Ensure model is loaded if not load_fastsam_model() or not fastsam_lib_imported or FastSAMPrompt is None: # Return None for image on critical error return None, "Error: FastSAM not loaded or library unavailable." print(f"Running FastSAM 'segment everything' with conf={conf_threshold}, iou={iou_threshold}...") output_image = None # Initialize output image status_message = "Processing..." # Initial status try: # Ensure image is RGB if image_pil.mode != "RGB": print(f"Converting image from {image_pil.mode} to RGB for FastSAM.") image_pil_rgb = image_pil.convert("RGB") else: image_pil_rgb = image_pil # Convert PIL Image to NumPy array (RGB) image_np_rgb = np.array(image_pil_rgb) print(f"Input image shape for FastSAM: {image_np_rgb.shape}") # Run FastSAM model # Make sure the arguments match what FastSAM expects everything_results = fastsam_model( image_np_rgb, device=DEVICE, retina_masks=True, imgsz=640, # Or another size FastSAM supports conf=conf_threshold, iou=iou_threshold, verbose=True # Keep verbose for debugging ) # Check if results are valid before creating prompt if everything_results is None or not isinstance(everything_results, list) or len(everything_results) == 0: print("FastSAM model returned None or empty results.") # Return original image and status return image_pil, "FastSAM did not return valid results." # Results might be in a different format, inspect 'everything_results' print(f"Type of everything_results: {type(everything_results)}") print(f"Length of everything_results: {len(everything_results)}") if len(everything_results) > 0: print(f"Type of first element: {type(everything_results[0])}") # Try to access potential attributes like 'masks' if it's an object if hasattr(everything_results[0], 'masks') and everything_results[0].masks is not None: print(f"Masks found in results object, shape: {everything_results[0].masks.data.shape}") else: print("First result element does not have 'masks' attribute or it's None.") # Process results with FastSAMPrompt # Ensure FastSAMPrompt class is available if FastSAMPrompt is None: print("FastSAMPrompt class is not available.") return image_pil, "Error: FastSAMPrompt class not loaded." prompt_process = FastSAMPrompt(image_np_rgb, everything_results, device=DEVICE) ann = prompt_process.everything_prompt() # Get all annotations # Check annotation format - Adjust based on actual FastSAM output structure # Assuming 'ann' is a list and the first element is a dictionary containing masks masks = None if isinstance(ann, list) and len(ann) > 0 and isinstance(ann[0], dict) and 'masks' in ann[0]: mask_tensor = ann[0]['masks'] if mask_tensor is not None and mask_tensor.numel() > 0: # Check if tensor is not None and not empty masks = mask_tensor.cpu().numpy() print(f"Found {len(masks)} masks with shape: {masks.shape}") else: print("Annotation 'masks' tensor is None or empty.") else: print(f"No masks found or annotation format unexpected. ann type: {type(ann)}") if isinstance(ann, list) and len(ann) > 0: print(f"First element of ann: {ann[0]}") # Prepare output image (start with original) output_image = image_pil.copy() # Draw masks if found if masks is not None and len(masks) > 0: # Ensure output_image is RGBA for compositing overlay = Image.new('RGBA', output_image.size, (0, 0, 0, 0)) draw = ImageDraw.Draw(overlay) for i, mask in enumerate(masks): # Ensure mask is boolean/binary before converting binary_mask = (mask > 0) # Use threshold 0 for binary mask from FastSAM output mask_uint8 = binary_mask.astype(np.uint8) * 255 if mask_uint8.max() == 0: # Skip empty masks # print(f"Skipping empty mask {i}") continue color = (random.randint(50, 255), random.randint(50, 255), random.randint(50, 255), 180) # RGBA color try: mask_image = Image.fromarray(mask_uint8, mode='L') # Grayscale mask # Draw the mask onto the overlay draw.bitmap((0, 0), mask_image, fill=color) except Exception as draw_err: print(f"Error drawing mask {i}: {draw_err}") traceback.print_exc() continue # Skip this mask # Composite the overlay onto the image try: output_image_rgba = output_image.convert('RGBA') output_image_composited = Image.alpha_composite(output_image_rgba, overlay) output_image = output_image_composited.convert('RGB') # Convert back to RGB for Gradio status_message = f"Segmentation complete. Found {len(masks)} masks." print("Mask drawing and compositing finished.") except Exception as comp_err: print(f"Error during alpha compositing: {comp_err}") traceback.print_exc() output_image = image_pil # Fallback to original image status_message = "Error during mask visualization." else: print("No masks detected or processed for 'segment everything' mode.") status_message = "No segments found or processed." output_image = image_pil # Return original image if no masks # Save for debugging before returning if output_image: try: debug_path = "debug_fastsam_everything_output.png" output_image.save(debug_path) print(f"Saved debug output to {debug_path}") except Exception as save_err: print(f"Failed to save debug image: {save_err}") return output_image, status_message # Return image and status message except Exception as e: print(f"Error during FastSAM 'everything' processing: {e}") traceback.print_exc() # Return original image and error message in case of failure return image_pil, f"Error during processing: {e}" def run_text_prompted_segmentation(image_pil: Image.Image, text_prompts: str, conf_threshold: float = 0.4, iou_threshold: float = 0.9): # Add input type check if not isinstance(image_pil, Image.Image): print(f"FastSAM Text input is not a PIL Image, type: {type(image_pil)}") if isinstance(image_pil, np.ndarray): try: image_pil = Image.fromarray(image_pil) print("Converted numpy input to PIL Image for FastSAM Text.") except Exception as e: print(f"Failed to convert numpy array to PIL Image: {e}") return None, "Error: Invalid image input format." else: return None, "Error: Please provide a valid image." # Ensure model is loaded if not load_fastsam_model() or not fastsam_lib_imported or FastSAMPrompt is None: return image_pil, "Error: FastSAM Model not loaded or library unavailable." # Return original image on load fail if not text_prompts: return image_pil, "Please enter text prompts (e.g., 'person, dog')." prompts = [p.strip() for p in text_prompts.split(',') if p.strip()] if not prompts: return image_pil, "No valid text prompts entered." print(f"Running FastSAM text-prompted segmentation for: {prompts} with conf={conf_threshold}, iou={iou_threshold}") output_image = None status_message = "Processing..." try: # Ensure image is RGB if image_pil.mode != "RGB": print(f"Converting image from {image_pil.mode} to RGB for FastSAM.") image_pil_rgb = image_pil.convert("RGB") else: image_pil_rgb = image_pil image_np_rgb = np.array(image_pil_rgb) print(f"Input image shape for FastSAM Text: {image_np_rgb.shape}") # Run FastSAM once to get all potential segments everything_results = fastsam_model( image_np_rgb, device=DEVICE, retina_masks=True, imgsz=640, # Use consistent args conf=conf_threshold, iou=iou_threshold, verbose=True ) # Check results if everything_results is None or not isinstance(everything_results, list) or len(everything_results) == 0: print("FastSAM model returned None or empty results for text prompt base.") return image_pil, "FastSAM did not return base results." # Initialize FastSAMPrompt if FastSAMPrompt is None: print("FastSAMPrompt class is not available.") return image_pil, "Error: FastSAMPrompt class not loaded." prompt_process = FastSAMPrompt(image_np_rgb, everything_results, device=DEVICE) all_matching_masks = [] found_prompts_details = [] # Store details like 'prompt (count)' # Process each text prompt for text in prompts: print(f" Processing prompt: '{text}'") # Get annotation for the specific text prompt ann = prompt_process.text_prompt(text=text) # Check annotation format and extract masks current_masks = None num_found = 0 # Adjust check based on actual structure of 'ann' for text_prompt if isinstance(ann, list) and len(ann) > 0 and isinstance(ann[0], dict) and 'masks' in ann[0]: mask_tensor = ann[0]['masks'] if mask_tensor is not None and mask_tensor.numel() > 0: current_masks = mask_tensor.cpu().numpy() num_found = len(current_masks) print(f" Found {num_found} mask(s) for '{text}'. Shape: {current_masks.shape}") all_matching_masks.extend(current_masks) # Add found masks to the list else: print(f" Annotation 'masks' tensor is None or empty for '{text}'.") else: print(f" No masks found or annotation format unexpected for '{text}'. ann type: {type(ann)}") if isinstance(ann, list) and len(ann) > 0: print(f" First element of ann for '{text}': {ann[0]}") found_prompts_details.append(f"{text} ({num_found})") # Record count for status # Prepare output image output_image = image_pil.copy() status_message = f"Results: {', '.join(found_prompts_details)}" if found_prompts_details else "No matches found for any prompt." # Draw all collected masks if any were found if all_matching_masks: print(f"Total masks collected across all prompts: {len(all_matching_masks)}") # Stack masks if needed (optional, can draw one by one) # masks_np = np.stack(all_matching_masks, axis=0) # print(f"Total masks stacked shape: {masks_np.shape}") overlay = Image.new('RGBA', output_image.size, (0, 0, 0, 0)) draw = ImageDraw.Draw(overlay) for i, mask in enumerate(all_matching_masks): # Iterate through collected masks binary_mask = (mask > 0) mask_uint8 = binary_mask.astype(np.uint8) * 255 if mask_uint8.max() == 0: continue # Skip empty masks # Assign a unique color per mask or per prompt (using random here) color = (random.randint(50, 255), random.randint(50, 255), random.randint(50, 255), 180) try: mask_image = Image.fromarray(mask_uint8, mode='L') draw.bitmap((0, 0), mask_image, fill=color) except Exception as draw_err: print(f"Error drawing collected mask {i}: {draw_err}") traceback.print_exc() continue # Composite the overlay try: output_image_rgba = output_image.convert('RGBA') output_image_composited = Image.alpha_composite(output_image_rgba, overlay) output_image = output_image_composited.convert('RGB') print("Text prompt mask drawing and compositing finished.") except Exception as comp_err: print(f"Error during alpha compositing for text prompts: {comp_err}") traceback.print_exc() output_image = image_pil # Fallback status_message += " (Error during visualization)" else: print("No matching masks found for any text prompt.") # status_message is already set # Save for debugging if output_image: try: debug_path = "debug_fastsam_text_output.png" output_image.save(debug_path) print(f"Saved debug output to {debug_path}") except Exception as save_err: print(f"Failed to save debug image: {save_err}") return output_image, status_message except Exception as e: print(f"Error during FastSAM text-prompted processing: {e}") traceback.print_exc() # Return original image and error message return image_pil, f"Error during processing: {e}" # --- Gradio Interface --- print("Attempting to preload models...") load_clip_model() # Preload CLIP load_fastsam_model() # Preload FastSAM print("Preloading finished (check logs above for errors).") # --- Gradio Interface Definition --- # (Your Gradio Blocks code remains largely the same, but ensure the outputs match the function returns) # --- Gradio Interface --- # ... (imports and functions) ... with gr.Blocks(theme=gr.themes.Soft()) as demo: # START of the block gr.Markdown("# CLIP & FastSAM Demo") # ... other UI elements ... with gr.Tabs(): with gr.TabItem("CLIP Zero-Shot Classification"): gr.Markdown("Upload an image and provide comma-separated labels...") with gr.Row(): with gr.Column(scale=1): clip_input_image = gr.Image(type="pil", label="Input Image") clip_text_labels = gr.Textbox(label="Comma-Separated Labels", placeholder="e.g., astronaut, moon") # DEFINE the button clip_button = gr.Button("Run CLIP Classification", variant="primary") with gr.Column(scale=1): clip_output_label = gr.Label(label="Classification Probabilities") clip_output_image_display = gr.Image(type="pil", label="Input Image Preview") # ATTACH the click handler *inside* the block, after the button is defined clip_button.click( run_clip_zero_shot, inputs=[clip_input_image, clip_text_labels], outputs=[clip_output_label, clip_output_image_display] ) # ... CLIP examples ... with gr.TabItem("FastSAM Segment Everything"): # ... FastSAM Everything UI elements ... fastsam_button_all = gr.Button(...) # Define button # Attach click handler *inside* the block fastsam_button_all.click( run_fastsam_segmentation, inputs=[...], outputs=[...] ) # ... FastSAM Everything examples ... with gr.TabItem("Text-Prompted Segmentation"): # ... Text-Prompted UI elements ... prompt_button = gr.Button(...) # Define button # Attach click handler *inside* the block prompt_button.click( run_text_prompted_segmentation, inputs=[...], outputs=[...] ) # ... Text-Prompted examples ... # The `with` block ends here. # --- Example File Download (This is correctly outside the block) --- # ... download logic ... # --- Launch App (This is correctly outside the block) --- if __name__ == "__main__": print("Launching Gradio Demo...") demo.launch(debug=True)