Spaces:
				
			
			
	
			
			
		Sleeping
		
	
	
	
			
			
	
	
	
	
		
		
		Sleeping
		
	Upload app.py
Browse files
    	
        app.py
    ADDED
    
    | 
         @@ -0,0 +1,612 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            import torch
         
     | 
| 2 | 
         
            +
            import torchvision
         
     | 
| 3 | 
         
            +
            from torchvision.models.detection import FasterRCNN_ResNet50_FPN_Weights
         
     | 
| 4 | 
         
            +
            from transformers import DetrImageProcessor, DetrForObjectDetection
         
     | 
| 5 | 
         
            +
            from transformers import AutoImageProcessor, Mask2FormerForUniversalSegmentation
         
     | 
| 6 | 
         
            +
            from PIL import Image
         
     | 
| 7 | 
         
            +
            import numpy as np
         
     | 
| 8 | 
         
            +
            import matplotlib.pyplot as plt
         
     | 
| 9 | 
         
            +
            import matplotlib.patches as patches
         
     | 
| 10 | 
         
            +
            import gradio as gr
         
     | 
| 11 | 
         
            +
            import os
         
     | 
| 12 | 
         
            +
            import io
         
     | 
| 13 | 
         
            +
            import uuid
         
     | 
| 14 | 
         
            +
             
     | 
| 15 | 
         
            +
            # Load Faster R-CNN model with proper weight assignment
         
     | 
| 16 | 
         
            +
            frcnn_weights = FasterRCNN_ResNet50_FPN_Weights.DEFAULT
         
     | 
| 17 | 
         
            +
            frcnn_model = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights=None, progress=True)
         
     | 
| 18 | 
         
            +
            state_dict = torch.hub.load_state_dict_from_url(frcnn_weights.url, progress=True, map_location=torch.device('cpu'))
         
     | 
| 19 | 
         
            +
            frcnn_model.load_state_dict(state_dict, strict=False)
         
     | 
| 20 | 
         
            +
            frcnn_model.eval()
         
     | 
| 21 | 
         
            +
             
     | 
| 22 | 
         
            +
            # Load DETR model and processor
         
     | 
| 23 | 
         
            +
            detr_processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50")
         
     | 
| 24 | 
         
            +
            detr_model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50")
         
     | 
| 25 | 
         
            +
             
     | 
| 26 | 
         
            +
            # Load Mask R-CNN model
         
     | 
| 27 | 
         
            +
            maskrcnn_model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True)
         
     | 
| 28 | 
         
            +
            maskrcnn_model.eval()
         
     | 
| 29 | 
         
            +
             
     | 
| 30 | 
         
            +
            # Load Mask2Former model and processor
         
     | 
| 31 | 
         
            +
            mask2former_processor = AutoImageProcessor.from_pretrained("facebook/mask2former-swin-small-coco-instance")
         
     | 
| 32 | 
         
            +
            mask2former_model = Mask2FormerForUniversalSegmentation.from_pretrained("facebook/mask2former-swin-small-coco-instance")
         
     | 
| 33 | 
         
            +
            mask2former_model.eval()
         
     | 
| 34 | 
         
            +
             
     | 
| 35 | 
         
            +
            # COCO class names for Faster R-CNN and Mask R-CNN
         
     | 
| 36 | 
         
            +
            COCO_INSTANCE_CATEGORY_NAMES = [
         
     | 
| 37 | 
         
            +
                '__background__', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
         
     | 
| 38 | 
         
            +
                'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A', 'stop sign',
         
     | 
| 39 | 
         
            +
                'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
         
     | 
| 40 | 
         
            +
                'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack', 'umbrella', 'N/A', 'N/A',
         
     | 
| 41 | 
         
            +
                'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
         
     | 
| 42 | 
         
            +
                'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket',
         
     | 
| 43 | 
         
            +
                'bottle', 'N/A', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl',
         
     | 
| 44 | 
         
            +
                'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
         
     | 
| 45 | 
         
            +
                'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table',
         
     | 
| 46 | 
         
            +
                'N/A', 'N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone',
         
     | 
| 47 | 
         
            +
                'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A', 'book',
         
     | 
| 48 | 
         
            +
                'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush'
         
     | 
| 49 | 
         
            +
            ]
         
     | 
| 50 | 
         
            +
             
     | 
| 51 | 
         
            +
            # Mask2Former label map
         
     | 
| 52 | 
         
            +
            MASK2FORMER_COCO_NAMES = mask2former_model.config.id2label if hasattr(mask2former_model.config, "id2label") else {str(i): str(i) for i in range(133)}
         
     | 
| 53 | 
         
            +
             
     | 
| 54 | 
         
            +
            def detect_objects_frcnn(image, threshold=0.5):
         
     | 
| 55 | 
         
            +
                """Run Faster R-CNN detection."""
         
     | 
| 56 | 
         
            +
                if image is None:
         
     | 
| 57 | 
         
            +
                    blank_img = Image.new('RGB', (400, 400), color='white')
         
     | 
| 58 | 
         
            +
                    plt.figure(figsize=(10, 10))
         
     | 
| 59 | 
         
            +
                    plt.imshow(blank_img)
         
     | 
| 60 | 
         
            +
                    plt.text(0.5, 0.5, "No image provided", horizontalalignment='center', verticalalignment='center',
         
     | 
| 61 | 
         
            +
                             transform=plt.gca().transAxes, fontsize=20)
         
     | 
| 62 | 
         
            +
                    plt.axis('off')
         
     | 
| 63 | 
         
            +
                    output_path = f"frcnn_blank_output_{uuid.uuid4()}.png"
         
     | 
| 64 | 
         
            +
                    plt.savefig(output_path)
         
     | 
| 65 | 
         
            +
                    plt.close()
         
     | 
| 66 | 
         
            +
                    return output_path, 0
         
     | 
| 67 | 
         
            +
             
     | 
| 68 | 
         
            +
                try:
         
     | 
| 69 | 
         
            +
                    threshold = float(threshold) if threshold is not None else 0.5
         
     | 
| 70 | 
         
            +
                    image = image.convert('RGB')
         
     | 
| 71 | 
         
            +
                    img_array = np.array(image).astype(np.float32) / 255.0
         
     | 
| 72 | 
         
            +
                    transform = frcnn_weights.transforms()
         
     | 
| 73 | 
         
            +
                    image_tensor = transform(Image.fromarray((img_array * 255).astype(np.uint8))).unsqueeze(0)
         
     | 
| 74 | 
         
            +
             
     | 
| 75 | 
         
            +
                    with torch.no_grad():
         
     | 
| 76 | 
         
            +
                        prediction = frcnn_model(image_tensor)[0]
         
     | 
| 77 | 
         
            +
             
     | 
| 78 | 
         
            +
                    boxes = prediction['boxes'].cpu().numpy()
         
     | 
| 79 | 
         
            +
                    labels = prediction['labels'].cpu().numpy()
         
     | 
| 80 | 
         
            +
                    scores = prediction['scores'].cpu().numpy()
         
     | 
| 81 | 
         
            +
             
     | 
| 82 | 
         
            +
                    valid_detections = sum(1 for score in scores if score >= threshold)
         
     | 
| 83 | 
         
            +
             
     | 
| 84 | 
         
            +
                    image_np = np.array(image)
         
     | 
| 85 | 
         
            +
                    plt.figure(figsize=(10, 10))
         
     | 
| 86 | 
         
            +
                    plt.imshow(image_np)
         
     | 
| 87 | 
         
            +
                    ax = plt.gca()
         
     | 
| 88 | 
         
            +
             
     | 
| 89 | 
         
            +
                    for box, label, score in zip(boxes, labels, scores):
         
     | 
| 90 | 
         
            +
                        if score >= threshold:
         
     | 
| 91 | 
         
            +
                            x1, y1, x2, y2 = box
         
     | 
| 92 | 
         
            +
                            ax.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, color='red', linewidth=2))
         
     | 
| 93 | 
         
            +
                            class_name = COCO_INSTANCE_CATEGORY_NAMES[label]
         
     | 
| 94 | 
         
            +
                            ax.text(x1, y1, f'{class_name}: {score:.2f}', bbox=dict(facecolor='yellow', alpha=0.5), fontsize=12, color='black')
         
     | 
| 95 | 
         
            +
             
     | 
| 96 | 
         
            +
                    plt.axis('off')
         
     | 
| 97 | 
         
            +
                    plt.tight_layout()
         
     | 
| 98 | 
         
            +
                    output_path = f"frcnn_output_{uuid.uuid4()}.png"
         
     | 
| 99 | 
         
            +
                    plt.savefig(output_path)
         
     | 
| 100 | 
         
            +
                    plt.close()
         
     | 
| 101 | 
         
            +
                    return output_path, valid_detections
         
     | 
| 102 | 
         
            +
                except Exception as e:
         
     | 
| 103 | 
         
            +
                    error_img = Image.new('RGB', (400, 400), color='white')
         
     | 
| 104 | 
         
            +
                    plt.figure(figsize=(10, 10))
         
     | 
| 105 | 
         
            +
                    plt.imshow(error_img)
         
     | 
| 106 | 
         
            +
                    plt.text(0.5, 0.5, f"Error: {str(e)}", horizontalalignment='center', verticalalignment='center',
         
     | 
| 107 | 
         
            +
                             transform=plt.gca().transAxes, fontsize=12, wrap=True)
         
     | 
| 108 | 
         
            +
                    plt.axis('off')
         
     | 
| 109 | 
         
            +
                    error_path = f"frcnn_error_output_{uuid.uuid4()}.png"
         
     | 
| 110 | 
         
            +
                    plt.savefig(error_path)
         
     | 
| 111 | 
         
            +
                    plt.close()
         
     | 
| 112 | 
         
            +
                    return error_path, 0
         
     | 
| 113 | 
         
            +
             
     | 
| 114 | 
         
            +
            def detect_objects_detr(image, threshold=0.9):
         
     | 
| 115 | 
         
            +
                """Run DETR detection."""
         
     | 
| 116 | 
         
            +
                if image is None:
         
     | 
| 117 | 
         
            +
                    blank_img = Image.new('RGB', (400, 400), color='white')
         
     | 
| 118 | 
         
            +
                    fig, ax = plt.subplots(1, figsize=(10, 10))
         
     | 
| 119 | 
         
            +
                    ax.imshow(blank_img)
         
     | 
| 120 | 
         
            +
                    ax.text(0.5, 0.5, "No image provided", horizontalalignment='center', verticalalignment='center',
         
     | 
| 121 | 
         
            +
                            transform=ax.transAxes, fontsize=20)
         
     | 
| 122 | 
         
            +
                    plt.axis('off')
         
     | 
| 123 | 
         
            +
                    output_path = f"detr_blank_output_{uuid.uuid4()}.png"
         
     | 
| 124 | 
         
            +
                    plt.savefig(output_path)
         
     | 
| 125 | 
         
            +
                    plt.close(fig)
         
     | 
| 126 | 
         
            +
                    return output_path, 0
         
     | 
| 127 | 
         
            +
             
     | 
| 128 | 
         
            +
                try:
         
     | 
| 129 | 
         
            +
                    image = image.convert('RGB')
         
     | 
| 130 | 
         
            +
                    inputs = detr_processor(images=image, return_tensors="pt")
         
     | 
| 131 | 
         
            +
                    outputs = detr_model(**inputs)
         
     | 
| 132 | 
         
            +
                    target_sizes = torch.tensor([image.size[::-1]])
         
     | 
| 133 | 
         
            +
                    results = detr_processor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=threshold)[0]
         
     | 
| 134 | 
         
            +
             
     | 
| 135 | 
         
            +
                    valid_detections = len(results["scores"])
         
     | 
| 136 | 
         
            +
             
     | 
| 137 | 
         
            +
                    fig, ax = plt.subplots(1, figsize=(10, 10))
         
     | 
| 138 | 
         
            +
                    ax.imshow(image)
         
     | 
| 139 | 
         
            +
             
     | 
| 140 | 
         
            +
                    for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
         
     | 
| 141 | 
         
            +
                        xmin, ymin, xmax, ymax = box.tolist()
         
     | 
| 142 | 
         
            +
                        ax.add_patch(patches.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin, linewidth=2, edgecolor='red', facecolor='none'))
         
     | 
| 143 | 
         
            +
                        ax.text(xmin, ymin, f"{detr_model.config.id2label[label.item()]}: {round(score.item(), 2)}",
         
     | 
| 144 | 
         
            +
                                bbox=dict(facecolor='yellow', alpha=0.5), fontsize=8)
         
     | 
| 145 | 
         
            +
             
     | 
| 146 | 
         
            +
                    plt.axis('off')
         
     | 
| 147 | 
         
            +
                    output_path = f"detr_output_{uuid.uuid4()}.png"
         
     | 
| 148 | 
         
            +
                    plt.savefig(output_path)
         
     | 
| 149 | 
         
            +
                    plt.close(fig)
         
     | 
| 150 | 
         
            +
                    return output_path, valid_detections
         
     | 
| 151 | 
         
            +
                except Exception as e:
         
     | 
| 152 | 
         
            +
                    error_img = Image.new('RGB', (400, 400), color='white')
         
     | 
| 153 | 
         
            +
                    fig, ax = plt.subplots(1, figsize=(10, 10))
         
     | 
| 154 | 
         
            +
                    ax.imshow(error_img)
         
     | 
| 155 | 
         
            +
                    ax.text(0.5, 0.5, f"Error: {str(e)}", horizontalalignment='center', verticalalignment='center',
         
     | 
| 156 | 
         
            +
                            transform=ax.transAxes, fontsize=12, wrap=True)
         
     | 
| 157 | 
         
            +
                    plt.axis('off')
         
     | 
| 158 | 
         
            +
                    error_path = f"detr_error_output_{uuid.uuid4()}.png"
         
     | 
| 159 | 
         
            +
                    plt.savefig(error_path)
         
     | 
| 160 | 
         
            +
                    plt.close(fig)
         
     | 
| 161 | 
         
            +
                    return error_path, 0
         
     | 
| 162 | 
         
            +
             
     | 
| 163 | 
         
            +
            def detect_objects_maskrcnn(image, threshold=0.5):
         
     | 
| 164 | 
         
            +
                """Run Mask R-CNN detection and segmentation."""
         
     | 
| 165 | 
         
            +
                if image is None:
         
     | 
| 166 | 
         
            +
                    blank_img = Image.new('RGB', (400, 400), color='white')
         
     | 
| 167 | 
         
            +
                    plt.figure(figsize=(10, 10))
         
     | 
| 168 | 
         
            +
                    plt.imshow(blank_img)
         
     | 
| 169 | 
         
            +
                    plt.text(0.5, 0.5, "No image provided", horizontalalignment='center', verticalalignment='center',
         
     | 
| 170 | 
         
            +
                             transform=plt.gca().transAxes, fontsize=20)
         
     | 
| 171 | 
         
            +
                    plt.axis('off')
         
     | 
| 172 | 
         
            +
                    output_path = f"maskrcnn_blank_output_{uuid.uuid4()}.png"
         
     | 
| 173 | 
         
            +
                    plt.savefig(output_path)
         
     | 
| 174 | 
         
            +
                    plt.close()
         
     | 
| 175 | 
         
            +
                    return output_path, 0
         
     | 
| 176 | 
         
            +
             
     | 
| 177 | 
         
            +
                try:
         
     | 
| 178 | 
         
            +
                    image = image.convert('RGB')
         
     | 
| 179 | 
         
            +
                    transform = torchvision.transforms.ToTensor()
         
     | 
| 180 | 
         
            +
                    img_tensor = transform(image).unsqueeze(0)
         
     | 
| 181 | 
         
            +
             
     | 
| 182 | 
         
            +
                    with torch.no_grad():
         
     | 
| 183 | 
         
            +
                        output = maskrcnn_model(img_tensor)[0]
         
     | 
| 184 | 
         
            +
             
     | 
| 185 | 
         
            +
                    masks = output['masks']
         
     | 
| 186 | 
         
            +
                    boxes = output['boxes'].cpu().numpy()
         
     | 
| 187 | 
         
            +
                    labels = output['labels'].cpu().numpy()
         
     | 
| 188 | 
         
            +
                    scores = output['scores'].cpu().numpy()
         
     | 
| 189 | 
         
            +
             
     | 
| 190 | 
         
            +
                    valid_detections = sum(1 for score in scores if score >= threshold)
         
     | 
| 191 | 
         
            +
             
     | 
| 192 | 
         
            +
                    image_np = np.array(image).copy()
         
     | 
| 193 | 
         
            +
                    fig, ax = plt.subplots(1, figsize=(10, 10))
         
     | 
| 194 | 
         
            +
                    ax.imshow(image_np)
         
     | 
| 195 | 
         
            +
             
     | 
| 196 | 
         
            +
                    for i in range(len(masks)):
         
     | 
| 197 | 
         
            +
                        if scores[i] >= threshold:
         
     | 
| 198 | 
         
            +
                            mask = masks[i, 0].cpu().numpy()
         
     | 
| 199 | 
         
            +
                            mask = mask > 0.5
         
     | 
| 200 | 
         
            +
                            color = np.random.rand(3)
         
     | 
| 201 | 
         
            +
                            colored_mask = np.zeros_like(image_np, dtype=np.uint8)
         
     | 
| 202 | 
         
            +
                            for c in range(3):
         
     | 
| 203 | 
         
            +
                                colored_mask[:, :, c] = mask * int(color[c] * 255)
         
     | 
| 204 | 
         
            +
                            image_np = np.where(mask[:, :, None], 0.5 * image_np + 0.5 * colored_mask, image_np).astype(np.uint8)
         
     | 
| 205 | 
         
            +
             
     | 
| 206 | 
         
            +
                            x1, y1, x2, y2 = boxes[i]
         
     | 
| 207 | 
         
            +
                            ax.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, color=color, linewidth=2))
         
     | 
| 208 | 
         
            +
                            label = COCO_INSTANCE_CATEGORY_NAMES[labels[i]]
         
     | 
| 209 | 
         
            +
                            ax.text(x1, y1, f"{label}: {scores[i]:.2f}", bbox=dict(facecolor='yellow', alpha=0.5), fontsize=10)
         
     | 
| 210 | 
         
            +
             
     | 
| 211 | 
         
            +
                    ax.imshow(image_np)
         
     | 
| 212 | 
         
            +
                    ax.axis('off')
         
     | 
| 213 | 
         
            +
                    output_path = f"maskrcnn_output_{uuid.uuid4()}.png"
         
     | 
| 214 | 
         
            +
                    plt.savefig(output_path, bbox_inches='tight', pad_inches=0)
         
     | 
| 215 | 
         
            +
                    plt.close()
         
     | 
| 216 | 
         
            +
                    return output_path, valid_detections
         
     | 
| 217 | 
         
            +
                except Exception as e:
         
     | 
| 218 | 
         
            +
                    error_img = Image.new('RGB', (400, 400), color='white')
         
     | 
| 219 | 
         
            +
                    plt.figure(figsize=(10, 10))
         
     | 
| 220 | 
         
            +
                    plt.imshow(error_img)
         
     | 
| 221 | 
         
            +
                    plt.text(0.5, 0.5, f"Error: {str(e)}", horizontalalignment='center', verticalalignment='center',
         
     | 
| 222 | 
         
            +
                             transform=plt.gca().transAxes, fontsize=12, wrap=True)
         
     | 
| 223 | 
         
            +
                    plt.axis('off')
         
     | 
| 224 | 
         
            +
                    error_path = f"maskrcnn_error_output_{uuid.uuid4()}.png"
         
     | 
| 225 | 
         
            +
                    plt.savefig(error_path)
         
     | 
| 226 | 
         
            +
                    plt.close()
         
     | 
| 227 | 
         
            +
                    return error_path, 0
         
     | 
| 228 | 
         
            +
             
     | 
| 229 | 
         
            +
            def detect_objects_mask2former(image, threshold=0.5):
         
     | 
| 230 | 
         
            +
                """Run Mask2Former detection and segmentation."""
         
     | 
| 231 | 
         
            +
                if image is None:
         
     | 
| 232 | 
         
            +
                    blank_img = Image.new('RGB', (400, 400), color='white')
         
     | 
| 233 | 
         
            +
                    plt.figure(figsize=(10, 10))
         
     | 
| 234 | 
         
            +
                    plt.imshow(blank_img)
         
     | 
| 235 | 
         
            +
                    plt.text(0.5, 0.5, "No image provided", horizontalalignment='center', verticalalignment='center',
         
     | 
| 236 | 
         
            +
                             transform=plt.gca().transAxes, fontsize=20)
         
     | 
| 237 | 
         
            +
                    plt.axis('off')
         
     | 
| 238 | 
         
            +
                    output_path = f"mask2former_blank_output_{uuid.uuid4()}.png"
         
     | 
| 239 | 
         
            +
                    plt.savefig(output_path)
         
     | 
| 240 | 
         
            +
                    plt.close()
         
     | 
| 241 | 
         
            +
                    return output_path, 0
         
     | 
| 242 | 
         
            +
             
     | 
| 243 | 
         
            +
                try:
         
     | 
| 244 | 
         
            +
                    image = image.convert('RGB')
         
     | 
| 245 | 
         
            +
                    inputs = mask2former_processor(images=image, return_tensors="pt")
         
     | 
| 246 | 
         
            +
                    with torch.no_grad():
         
     | 
| 247 | 
         
            +
                        outputs = mask2former_model(**inputs)
         
     | 
| 248 | 
         
            +
             
     | 
| 249 | 
         
            +
                    results = mask2former_processor.post_process_instance_segmentation(outputs, target_sizes=[image.size[::-1]])[0]
         
     | 
| 250 | 
         
            +
                    segmentation_map = results["segmentation"].cpu().numpy()
         
     | 
| 251 | 
         
            +
                    segments_info = results["segments_info"]
         
     | 
| 252 | 
         
            +
             
     | 
| 253 | 
         
            +
                    valid_detections = sum(1 for segment in segments_info if segment.get("score", 1.0) >= threshold)
         
     | 
| 254 | 
         
            +
             
     | 
| 255 | 
         
            +
                    image_np = np.array(image).copy()
         
     | 
| 256 | 
         
            +
                    overlay = image_np.copy()
         
     | 
| 257 | 
         
            +
                    fig, ax = plt.subplots(1, figsize=(10, 10))
         
     | 
| 258 | 
         
            +
                    ax.imshow(image_np)
         
     | 
| 259 | 
         
            +
             
     | 
| 260 | 
         
            +
                    for segment in segments_info:
         
     | 
| 261 | 
         
            +
                        score = segment.get("score", 1.0)
         
     | 
| 262 | 
         
            +
                        if score < threshold:
         
     | 
| 263 | 
         
            +
                            continue
         
     | 
| 264 | 
         
            +
                        segment_id = segment["id"]
         
     | 
| 265 | 
         
            +
                        label_id = segment["label_id"]
         
     | 
| 266 | 
         
            +
                        mask = segmentation_map == segment_id
         
     | 
| 267 | 
         
            +
                        color = np.random.rand(3)
         
     | 
| 268 | 
         
            +
                        overlay[mask] = (overlay[mask] * 0.5 + np.array(color) * 255 * 0.5).astype(np.uint8)
         
     | 
| 269 | 
         
            +
             
     | 
| 270 | 
         
            +
                        y_indices, x_indices = np.where(mask)
         
     | 
| 271 | 
         
            +
                        if len(x_indices) == 0 or len(y_indices) == 0:
         
     | 
| 272 | 
         
            +
                            continue
         
     | 
| 273 | 
         
            +
                        x1, x2 = x_indices.min(), x_indices.max()
         
     | 
| 274 | 
         
            +
                        y1, y2 = y_indices.min(), y_indices.max()
         
     | 
| 275 | 
         
            +
             
     | 
| 276 | 
         
            +
                        label_name = MASK2FORMER_COCO_NAMES.get(str(label_id), str(label_id))
         
     | 
| 277 | 
         
            +
                        ax.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, color=color, linewidth=2))
         
     | 
| 278 | 
         
            +
                        ax.text(x1, y1, f"{label_name}: {score:.2f}", bbox=dict(facecolor='yellow', alpha=0.5), fontsize=10)
         
     | 
| 279 | 
         
            +
             
     | 
| 280 | 
         
            +
                    ax.imshow(overlay)
         
     | 
| 281 | 
         
            +
                    ax.axis('off')
         
     | 
| 282 | 
         
            +
                    output_path = f"mask2former_output_{uuid.uuid4()}.png"
         
     | 
| 283 | 
         
            +
                    plt.savefig(output_path, bbox_inches='tight', pad_inches=0)
         
     | 
| 284 | 
         
            +
                    plt.close()
         
     | 
| 285 | 
         
            +
                    return output_path, valid_detections
         
     | 
| 286 | 
         
            +
                except Exception as e:
         
     | 
| 287 | 
         
            +
                    error_img = Image.new('RGB', (400, 400), color='white')
         
     | 
| 288 | 
         
            +
                    plt.figure(figsize=(10, 10))
         
     | 
| 289 | 
         
            +
                    plt.imshow(error_img)
         
     | 
| 290 | 
         
            +
                    plt.text(0.5, 0.5, f"Error: {str(e)}", horizontalalignment='center', verticalalignment='center',
         
     | 
| 291 | 
         
            +
                             transform=plt.gca().transAxes, fontsize=12, wrap=True)
         
     | 
| 292 | 
         
            +
                    plt.axis('off')
         
     | 
| 293 | 
         
            +
                    error_path = f"mask2former_error_output_{uuid.uuid4()}.png"
         
     | 
| 294 | 
         
            +
                    plt.savefig(error_path)
         
     | 
| 295 | 
         
            +
                    plt.close()
         
     | 
| 296 | 
         
            +
                    return error_path, 0
         
     | 
| 297 | 
         
            +
             
     | 
| 298 | 
         
            +
            def update_model_choices(category):
         
     | 
| 299 | 
         
            +
                """Update model choices for prediction radio buttons based on selected category."""
         
     | 
| 300 | 
         
            +
                if category == "Object Detection":
         
     | 
| 301 | 
         
            +
                    return gr.update(choices=["ConvNet (Faster R-CNN)", "Transformer (DETR)"], value=None, visible=True), gr.update(visible=True), gr.update(visible=True), gr.update(visible=False), gr.update(visible=False)
         
     | 
| 302 | 
         
            +
                elif category == "Object Segmentation":
         
     | 
| 303 | 
         
            +
                    return gr.update(choices=["ConvNet (Mask R-CNN)", "Transformer (Mask2Former)"], value=None, visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(visible=True), gr.update(visible=True)
         
     | 
| 304 | 
         
            +
                return gr.update(choices=[], visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)
         
     | 
| 305 | 
         
            +
             
     | 
| 306 | 
         
            +
            def analyze_performance(image, category, user_opinion, frcnn_threshold=0.5, detr_threshold=0.9, maskrcnn_threshold=0.5, mask2former_threshold=0.5):
         
     | 
| 307 | 
         
            +
                """Analyze and compare model performance for all models in the selected category."""
         
     | 
| 308 | 
         
            +
                if image is None:
         
     | 
| 309 | 
         
            +
                    return "Please upload an image first.", None, None, None, None, "No analysis available."
         
     | 
| 310 | 
         
            +
             
     | 
| 311 | 
         
            +
                frcnn_result = None
         
     | 
| 312 | 
         
            +
                detr_result = None
         
     | 
| 313 | 
         
            +
                maskrcnn_result = None
         
     | 
| 314 | 
         
            +
                mask2former_result = None
         
     | 
| 315 | 
         
            +
                frcnn_count = 0
         
     | 
| 316 | 
         
            +
                detr_count = 0
         
     | 
| 317 | 
         
            +
                maskrcnn_count = 0
         
     | 
| 318 | 
         
            +
                mask2former_count = 0
         
     | 
| 319 | 
         
            +
             
     | 
| 320 | 
         
            +
                if category == "Object Detection":
         
     | 
| 321 | 
         
            +
                    frcnn_result, frcnn_count = detect_objects_frcnn(image, frcnn_threshold)
         
     | 
| 322 | 
         
            +
                    detr_result, detr_count = detect_objects_detr(image, detr_threshold)
         
     | 
| 323 | 
         
            +
                elif category == "Object Segmentation":
         
     | 
| 324 | 
         
            +
                    maskrcnn_result, maskrcnn_count = detect_objects_maskrcnn(image, maskrcnn_threshold)
         
     | 
| 325 | 
         
            +
                    mask2former_result, mask2former_count = detect_objects_mask2former(image, mask2former_threshold)
         
     | 
| 326 | 
         
            +
             
     | 
| 327 | 
         
            +
                # Analyze performance
         
     | 
| 328 | 
         
            +
                counts = {}
         
     | 
| 329 | 
         
            +
                model_mapping = {
         
     | 
| 330 | 
         
            +
                    "ConvNet (Faster R-CNN)": "ConvNet (Faster R-CNN)",
         
     | 
| 331 | 
         
            +
                    "Transformer (DETR)": "Transformer (DETR)",
         
     | 
| 332 | 
         
            +
                    "ConvNet (Mask R-CNN)": "ConvNet (Mask R-CNN)",
         
     | 
| 333 | 
         
            +
                    "Transformer (Mask2Former)": "Transformer (Mask2Former)"
         
     | 
| 334 | 
         
            +
                }
         
     | 
| 335 | 
         
            +
                if category == "Object Detection":
         
     | 
| 336 | 
         
            +
                    counts = {
         
     | 
| 337 | 
         
            +
                        "ConvNet (Faster R-CNN)": frcnn_count,
         
     | 
| 338 | 
         
            +
                        "Transformer (DETR)": detr_count
         
     | 
| 339 | 
         
            +
                    }
         
     | 
| 340 | 
         
            +
                elif category == "Object Segmentation":
         
     | 
| 341 | 
         
            +
                    counts = {
         
     | 
| 342 | 
         
            +
                        "ConvNet (Mask R-CNN)": maskrcnn_count,
         
     | 
| 343 | 
         
            +
                        "Transformer (Mask2Former)": mask2former_count
         
     | 
| 344 | 
         
            +
                    }
         
     | 
| 345 | 
         
            +
             
     | 
| 346 | 
         
            +
                max_count = max(counts.values())
         
     | 
| 347 | 
         
            +
                max_models = [model for model, count in counts.items() if count == max_count]
         
     | 
| 348 | 
         
            +
             
     | 
| 349 | 
         
            +
                if len(max_models) == 1:
         
     | 
| 350 | 
         
            +
                    analysis = f"Result: {max_models[0]} performed best, identifying {max_count} objects.\n\n"
         
     | 
| 351 | 
         
            +
                else:
         
     | 
| 352 | 
         
            +
                    analysis = f"Result: {', '.join(max_models)} performed equally well, each identifying {max_count} objects.\n\n"
         
     | 
| 353 | 
         
            +
             
     | 
| 354 | 
         
            +
                if user_opinion:
         
     | 
| 355 | 
         
            +
                    analysis += f"You predicted that {user_opinion} would perform best.\n"
         
     | 
| 356 | 
         
            +
                    if user_opinion in max_models:
         
     | 
| 357 | 
         
            +
                        analysis += f"Congratulations, your prediction was correct!\n"
         
     | 
| 358 | 
         
            +
                    else:
         
     | 
| 359 | 
         
            +
                        analysis += f"Your prediction was not correct. {user_opinion} identified {counts[user_opinion]} objects, while {', '.join(max_models)} performed best with {max_count} objects. Please try again with a new image.\n"
         
     | 
| 360 | 
         
            +
             
     | 
| 361 | 
         
            +
                if category == "Object Detection":
         
     | 
| 362 | 
         
            +
                    analysis += "\nConvNet (Faster R-CNN) is efficient and reliable for general object identification tasks. Transformer (DETR) excels in complex scenes by leveraging advanced context understanding."
         
     | 
| 363 | 
         
            +
                elif category == "Object Segmentation":
         
     | 
| 364 | 
         
            +
                    analysis += "\nConvNet (Mask R-CNN) provides precise object outlines for detailed analysis. Transformer (Mask2Former) often outperforms in complex scenes due to its advanced architecture."
         
     | 
| 365 | 
         
            +
             
     | 
| 366 | 
         
            +
                # Image-specific recommendation
         
     | 
| 367 | 
         
            +
                img_array = np.array(image)
         
     | 
| 368 | 
         
            +
                height, width = img_array.shape[:2]
         
     | 
| 369 | 
         
            +
                pixel_variance = np.var(img_array)
         
     | 
| 370 | 
         
            +
             
     | 
| 371 | 
         
            +
                if height * width > 1000 * 1000:
         
     | 
| 372 | 
         
            +
                    analysis += f"\n\nThis high-resolution image benefits from Transformer models, which excel in detailed and complex scenes."
         
     | 
| 373 | 
         
            +
                if pixel_variance > 1000:
         
     | 
| 374 | 
         
            +
                    analysis += f"\n\nThis image has high complexity. Transformer models often provide superior results in such cases."
         
     | 
| 375 | 
         
            +
                if height * width < 500 * 500:
         
     | 
| 376 | 
         
            +
                    analysis += f"\n\nFor smaller images, ConvNet models often deliver reliable results with lower computational demands."
         
     | 
| 377 | 
         
            +
                if category == "Object Segmentation" and max_count > 0:
         
     | 
| 378 | 
         
            +
                    analysis += "\n\nFor detailed outlining tasks, Transformer (Mask2Former) may be preferable for complex scenes due to its advanced design."
         
     | 
| 379 | 
         
            +
             
     | 
| 380 | 
         
            +
                # Enhanced result formatting
         
     | 
| 381 | 
         
            +
                if user_opinion and user_opinion in max_models:
         
     | 
| 382 | 
         
            +
                    celebration = "๐โจ"
         
     | 
| 383 | 
         
            +
                    analysis = analysis.replace("Congratulations", f"{celebration} EPIC WIN! {celebration}")
         
     | 
| 384 | 
         
            +
                    analysis = analysis.replace("!\n", "! ๐ฅณ\n")
         
     | 
| 385 | 
         
            +
                    analysis += "\n\n๐ You've mastered the AI showdown! ๐"
         
     | 
| 386 | 
         
            +
                elif user_opinion:
         
     | 
| 387 | 
         
            +
                    analysis = analysis.replace("try again", "try again ๐ช")
         
     | 
| 388 | 
         
            +
             
     | 
| 389 | 
         
            +
                # Convert to HTML with styling
         
     | 
| 390 | 
         
            +
                html_analysis = f"""
         
     | 
| 391 | 
         
            +
                <div class="{'celebrate' if user_opinion in max_models else ''}" style="margin: 15px 0;">
         
     | 
| 392 | 
         
            +
                    <h3 style='color: {"#4CAF50" if user_opinion in max_models else "#f44336"}; margin-bottom: 15px;'>
         
     | 
| 393 | 
         
            +
                        {"๐ " + max_models[0] + " Dominates!" if len(max_models) == 1 else "โ๏ธ Tie Battle!"}
         
     | 
| 394 | 
         
            +
                    </h3>
         
     | 
| 395 | 
         
            +
                    <div style="background: var(--background-fill-primary); padding: 20px; border-radius: 10px; 
         
     | 
| 396 | 
         
            +
                                white-space: pre-wrap; overflow-wrap: break-word; color: var(--text-color);">
         
     | 
| 397 | 
         
            +
                        {analysis}
         
     | 
| 398 | 
         
            +
                    </div>
         
     | 
| 399 | 
         
            +
                </div>
         
     | 
| 400 | 
         
            +
                """
         
     | 
| 401 | 
         
            +
                return "Analysis complete!", frcnn_result, detr_result, maskrcnn_result, mask2former_result, html_analysis
         
     | 
| 402 | 
         
            +
             
     | 
| 403 | 
         
            +
            # Create Gradio interface with enhanced design
         
     | 
| 404 | 
         
            +
            with gr.Blocks(title="AI Vision Showdown", theme=gr.themes.Default(primary_hue="emerald", secondary_hue="blue")) as app:
         
     | 
| 405 | 
         
            +
                gr.Markdown("""
         
     | 
| 406 | 
         
            +
                # ๐ฏ AI Vision Showdown: ConvNets vs Transformers
         
     | 
| 407 | 
         
            +
                ### ๐ค Battle of the algorithms! Upload an image and predict which AI will dominate!
         
     | 
| 408 | 
         
            +
                """)
         
     | 
| 409 | 
         
            +
                
         
     | 
| 410 | 
         
            +
                # Enhanced CSS
         
     | 
| 411 | 
         
            +
                gr.HTML("""
         
     | 
| 412 | 
         
            +
                <style>
         
     | 
| 413 | 
         
            +
                    @keyframes celebrate {
         
     | 
| 414 | 
         
            +
                        0% { transform: rotate(0deg); }
         
     | 
| 415 | 
         
            +
                        25% { transform: rotate(5deg); }
         
     | 
| 416 | 
         
            +
                        50% { transform: rotate(-5deg); }
         
     | 
| 417 | 
         
            +
                        75% { transform: rotate(5deg); }
         
     | 
| 418 | 
         
            +
                        100% { transform: rotate(0deg); }
         
     | 
| 419 | 
         
            +
                    }
         
     | 
| 420 | 
         
            +
                    .celebrate { animation: celebrate 0.5s ease-in-out; }
         
     | 
| 421 | 
         
            +
                    .battle-card {
         
     | 
| 422 | 
         
            +
                        border-radius: 15px;
         
     | 
| 423 | 
         
            +
                        padding: 20px;
         
     | 
| 424 | 
         
            +
                        margin: 10px 0;
         
     | 
| 425 | 
         
            +
                        background: var(--background-fill-primary);
         
     | 
| 426 | 
         
            +
                        border: 1px solid var(--border-color-primary);
         
     | 
| 427 | 
         
            +
                    }
         
     | 
| 428 | 
         
            +
                    .analysis-box {
         
     | 
| 429 | 
         
            +
                        background: var(--background-fill-secondary) !important;
         
     | 
| 430 | 
         
            +
                        color: var(--text-color) !important;
         
     | 
| 431 | 
         
            +
                        padding: 20px;
         
     | 
| 432 | 
         
            +
                        border-radius: 10px;
         
     | 
| 433 | 
         
            +
                        white-space: pre-wrap;
         
     | 
| 434 | 
         
            +
                        overflow-wrap: break-word;
         
     | 
| 435 | 
         
            +
                    }
         
     | 
| 436 | 
         
            +
                    .loading-status {
         
     | 
| 437 | 
         
            +
                        padding: 15px;
         
     | 
| 438 | 
         
            +
                        background: var(--background-fill-secondary);
         
     | 
| 439 | 
         
            +
                        border-radius: 8px;
         
     | 
| 440 | 
         
            +
                        margin: 10px 0;
         
     | 
| 441 | 
         
            +
                        text-align: center;
         
     | 
| 442 | 
         
            +
                        font-weight: bold;
         
     | 
| 443 | 
         
            +
                    }
         
     | 
| 444 | 
         
            +
                </style>
         
     | 
| 445 | 
         
            +
                """)
         
     | 
| 446 | 
         
            +
             
     | 
| 447 | 
         
            +
                # State variables
         
     | 
| 448 | 
         
            +
                image_state = gr.State(None)
         
     | 
| 449 | 
         
            +
                category_state = gr.State(None)
         
     | 
| 450 | 
         
            +
                loading_status = gr.HTML(visible=False)
         
     | 
| 451 | 
         
            +
             
     | 
| 452 | 
         
            +
                # Top Section: Inputs
         
     | 
| 453 | 
         
            +
                with gr.Row(variant="battle-card"):
         
     | 
| 454 | 
         
            +
                    with gr.Column(scale=1, min_width=300):
         
     | 
| 455 | 
         
            +
                        gr.Markdown("## ๐ค Image Upload Zone")
         
     | 
| 456 | 
         
            +
                        image_input = gr.Image(type="pil", label="Drag & Drop Your Challenge Image")
         
     | 
| 457 | 
         
            +
                        upload_button = gr.Button("๐ผ Upload Challenge Image", variant="primary")
         
     | 
| 458 | 
         
            +
             
     | 
| 459 | 
         
            +
                    with gr.Column(scale=1, min_width=300):
         
     | 
| 460 | 
         
            +
                        with gr.Group(visible=False) as prediction_selection:
         
     | 
| 461 | 
         
            +
                            gr.Markdown("## ๐ฎ Prediction Arena")
         
     | 
| 462 | 
         
            +
                            category_choice = gr.Radio(
         
     | 
| 463 | 
         
            +
                                choices=["Object Detection", "Object Segmentation"],
         
     | 
| 464 | 
         
            +
                                label="โ๏ธ Select Battle Ground",
         
     | 
| 465 | 
         
            +
                                value=None,
         
     | 
| 466 | 
         
            +
                                elem_classes="battle-card"
         
     | 
| 467 | 
         
            +
                            )
         
     | 
| 468 | 
         
            +
                            user_opinion = gr.Radio(
         
     | 
| 469 | 
         
            +
                                choices=[],
         
     | 
| 470 | 
         
            +
                                label="๐น Predict the Victor",
         
     | 
| 471 | 
         
            +
                                value=None,
         
     | 
| 472 | 
         
            +
                                visible=False,
         
     | 
| 473 | 
         
            +
                                elem_classes="battle-card"
         
     | 
| 474 | 
         
            +
                            )
         
     | 
| 475 | 
         
            +
                            
         
     | 
| 476 | 
         
            +
                            # Enhanced threshold controls
         
     | 
| 477 | 
         
            +
                            with gr.Accordion("๐๏ธ Advanced Battle Parameters", open=False):
         
     | 
| 478 | 
         
            +
                                frcnn_threshold = gr.Slider(
         
     | 
| 479 | 
         
            +
                                    minimum=0.0, maximum=1.0, value=0.5, step=0.05,
         
     | 
| 480 | 
         
            +
                                    label="Faster R-CNN Confidence (Speed Demon ๐๏ธ)",
         
     | 
| 481 | 
         
            +
                                    visible=False
         
     | 
| 482 | 
         
            +
                                )
         
     | 
| 483 | 
         
            +
                                detr_threshold = gr.Slider(
         
     | 
| 484 | 
         
            +
                                    minimum=0.0, maximum=1.0, value=0.9, step=0.05,
         
     | 
| 485 | 
         
            +
                                    label="DETR Confidence (Attention Master ๐)",
         
     | 
| 486 | 
         
            +
                                    visible=False
         
     | 
| 487 | 
         
            +
                                )
         
     | 
| 488 | 
         
            +
                                maskrcnn_threshold = gr.Slider(
         
     | 
| 489 | 
         
            +
                                    minimum=0.0, maximum=1.0, value=0.5, step=0.05,
         
     | 
| 490 | 
         
            +
                                    label="Mask R-CNN Confidence (Precision Expert โ๏ธ)",
         
     | 
| 491 | 
         
            +
                                    visible=False
         
     | 
| 492 | 
         
            +
                                )
         
     | 
| 493 | 
         
            +
                                mask2former_threshold = gr.Slider(
         
     | 
| 494 | 
         
            +
                                    minimum=0.0, maximum=1.0, value=0.5, step=0.05,
         
     | 
| 495 | 
         
            +
                                    label="Mask2Former Confidence (Transformer Champ ๐ค)",
         
     | 
| 496 | 
         
            +
                                    visible=False
         
     | 
| 497 | 
         
            +
                                )
         
     | 
| 498 | 
         
            +
             
     | 
| 499 | 
         
            +
                            detect_button = gr.Button("โ๏ธ Start Showdown", variant="primary")
         
     | 
| 500 | 
         
            +
             
     | 
| 501 | 
         
            +
                # Results Section
         
     | 
| 502 | 
         
            +
                with gr.Group(visible=False) as outputs_panel:
         
     | 
| 503 | 
         
            +
                    gr.Markdown("## ๐ Battle Results")
         
     | 
| 504 | 
         
            +
                    with gr.Tabs():
         
     | 
| 505 | 
         
            +
                        with gr.TabItem("Object Detection Warriors", visible=False) as detection_tab:
         
     | 
| 506 | 
         
            +
                            with gr.Row():
         
     | 
| 507 | 
         
            +
                                frcnn_result = gr.Image(type="filepath", label="๐ Faster R-CNN (ConvNet Champion)", elem_classes="battle-card")
         
     | 
| 508 | 
         
            +
                                detr_result = gr.Image(type="filepath", label="๐ง  DETR (Transformer Visionary)", elem_classes="battle-card")
         
     | 
| 509 | 
         
            +
                        
         
     | 
| 510 | 
         
            +
                        with gr.TabItem("Segmentation Gladiators", visible=False) as segmentation_tab:
         
     | 
| 511 | 
         
            +
                            with gr.Row():
         
     | 
| 512 | 
         
            +
                                maskrcnn_result = gr.Image(type="filepath", label="โ๏ธ Mask R-CNN (Pixel Perfect)", elem_classes="battle-card")
         
     | 
| 513 | 
         
            +
                                mask2former_result = gr.Image(type="filepath", label="๐ก๏ธ Mask2Former (Segmentation Master)", elem_classes="battle-card")
         
     | 
| 514 | 
         
            +
             
     | 
| 515 | 
         
            +
                # Analysis Section
         
     | 
| 516 | 
         
            +
                with gr.Group(visible=False) as results_panel:
         
     | 
| 517 | 
         
            +
                    gr.Markdown("## ๐ Battle Report")
         
     | 
| 518 | 
         
            +
                    analysis_output = gr.HTML(label="Victory Analysis", elem_classes="battle-card")
         
     | 
| 519 | 
         
            +
                    restart_button = gr.Button("๐ New Challenge", variant="secondary")
         
     | 
| 520 | 
         
            +
             
     | 
| 521 | 
         
            +
                # Upload button click event
         
     | 
| 522 | 
         
            +
                def upload_image(img):
         
     | 
| 523 | 
         
            +
                    if img is None:
         
     | 
| 524 | 
         
            +
                        return None, gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)
         
     | 
| 525 | 
         
            +
                    return img, gr.update(visible=True), gr.update(visible=False), gr.update(visible=False)
         
     | 
| 526 | 
         
            +
             
     | 
| 527 | 
         
            +
                upload_button.click(
         
     | 
| 528 | 
         
            +
                    fn=upload_image,
         
     | 
| 529 | 
         
            +
                    inputs=[image_input],
         
     | 
| 530 | 
         
            +
                    outputs=[image_state, prediction_selection, outputs_panel, results_panel]
         
     | 
| 531 | 
         
            +
                )
         
     | 
| 532 | 
         
            +
             
     | 
| 533 | 
         
            +
                # Category selection event
         
     | 
| 534 | 
         
            +
                def update_prediction_options(category):
         
     | 
| 535 | 
         
            +
                    if category is None:
         
     | 
| 536 | 
         
            +
                        return None, gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)
         
     | 
| 537 | 
         
            +
                    model_update, frcnn_vis, detr_vis, maskrcnn_vis, mask2former_vis = update_model_choices(category)
         
     | 
| 538 | 
         
            +
                    return category, model_update, frcnn_vis, detr_vis, maskrcnn_vis, mask2former_vis
         
     | 
| 539 | 
         
            +
             
     | 
| 540 | 
         
            +
                category_choice.change(
         
     | 
| 541 | 
         
            +
                    fn=update_prediction_options,
         
     | 
| 542 | 
         
            +
                    inputs=[category_choice],
         
     | 
| 543 | 
         
            +
                    outputs=[category_state, user_opinion, frcnn_threshold, detr_threshold, maskrcnn_threshold, mask2former_threshold]
         
     | 
| 544 | 
         
            +
                )
         
     | 
| 545 | 
         
            +
             
     | 
| 546 | 
         
            +
                # Detect button click event
         
     | 
| 547 | 
         
            +
                def run_detection(image, category, user_opinion, frcnn_threshold, detr_threshold, maskrcnn_threshold, mask2former_threshold):
         
     | 
| 548 | 
         
            +
                    if not category or not user_opinion:
         
     | 
| 549 | 
         
            +
                        return "Please select a category and prediction.", None, None, None, None, "No analysis available.", gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)
         
     | 
| 550 | 
         
            +
                    
         
     | 
| 551 | 
         
            +
                    def analyze_with_progress(progress=gr.Progress()):
         
     | 
| 552 | 
         
            +
                        progress(0.1, desc="โ๏ธ Models are gearing up...")
         
     | 
| 553 | 
         
            +
                        result = analyze_performance(image, category, user_opinion, frcnn_threshold, detr_threshold, maskrcnn_threshold, mask2former_threshold)
         
     | 
| 554 | 
         
            +
                        progress(1.0, desc="โ
 Battle complete!")
         
     | 
| 555 | 
         
            +
                        return result
         
     | 
| 556 | 
         
            +
                    
         
     | 
| 557 | 
         
            +
                    try:
         
     | 
| 558 | 
         
            +
                        message, frcnn_result_img, detr_result_img, maskrcnn_result_img, mask2former_result_img, html_analysis = analyze_with_progress()
         
     | 
| 559 | 
         
            +
                        return [
         
     | 
| 560 | 
         
            +
                            message,
         
     | 
| 561 | 
         
            +
                            gr.update(value=frcnn_result_img, visible=category == "Object Detection"),
         
     | 
| 562 | 
         
            +
                            gr.update(value=detr_result_img, visible=category == "Object Detection"),
         
     | 
| 563 | 
         
            +
                            gr.update(value=maskrcnn_result_img, visible=category == "Object Segmentation"),
         
     | 
| 564 | 
         
            +
                            gr.update(value=mask2former_result_img, visible=category == "Object Segmentation"),
         
     | 
| 565 | 
         
            +
                            html_analysis,
         
     | 
| 566 | 
         
            +
                            gr.update(visible=True),
         
     | 
| 567 | 
         
            +
                            gr.update(visible=True),
         
     | 
| 568 | 
         
            +
                            gr.update(visible=category == "Object Detection"),
         
     | 
| 569 | 
         
            +
                            gr.update(visible=category == "Object Segmentation"),
         
     | 
| 570 | 
         
            +
                            gr.update(visible=False)
         
     | 
| 571 | 
         
            +
                        ]
         
     | 
| 572 | 
         
            +
                    except Exception as e:
         
     | 
| 573 | 
         
            +
                        return [f"Error: {str(e)}"] + [gr.update()]*9 + [gr.update(visible=False)]
         
     | 
| 574 | 
         
            +
             
     | 
| 575 | 
         
            +
                detect_button.click(
         
     | 
| 576 | 
         
            +
                    fn=run_detection,
         
     | 
| 577 | 
         
            +
                    inputs=[image_state, category_state, user_opinion, frcnn_threshold, detr_threshold, maskrcnn_threshold, mask2former_threshold],
         
     | 
| 578 | 
         
            +
                    outputs=[gr.Textbox(visible=False), frcnn_result, detr_result, maskrcnn_result, mask2former_result, 
         
     | 
| 579 | 
         
            +
                            analysis_output, outputs_panel, results_panel, detection_tab, segmentation_tab, loading_status]
         
     | 
| 580 | 
         
            +
                )
         
     | 
| 581 | 
         
            +
             
     | 
| 582 | 
         
            +
                # Restart button click event
         
     | 
| 583 | 
         
            +
                def restart():
         
     | 
| 584 | 
         
            +
                    return None, None, gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)
         
     | 
| 585 | 
         
            +
             
     | 
| 586 | 
         
            +
                restart_button.click(
         
     | 
| 587 | 
         
            +
                    fn=restart,
         
     | 
| 588 | 
         
            +
                    inputs=[],
         
     | 
| 589 | 
         
            +
                    outputs=[image_state, category_state, prediction_selection, outputs_panel, results_panel, frcnn_result, detr_result, maskrcnn_result, mask2former_result, analysis_output, user_opinion, category_choice, detection_tab, segmentation_tab]
         
     | 
| 590 | 
         
            +
                )
         
     | 
| 591 | 
         
            +
             
     | 
| 592 | 
         
            +
                # Example images
         
     | 
| 593 | 
         
            +
                example_images = [
         
     | 
| 594 | 
         
            +
                    os.path.join(os.getcwd(), "TEST_IMG_1.jpg"),
         
     | 
| 595 | 
         
            +
                    os.path.join(os.getcwd(), "TEST_IMG_2.JPG"),
         
     | 
| 596 | 
         
            +
                    os.path.join(os.getcwd(), "TEST_IMG_3.jpg"),
         
     | 
| 597 | 
         
            +
                    os.path.join(os.getcwd(), "TEST_IMG_4.jpg")
         
     | 
| 598 | 
         
            +
                ]
         
     | 
| 599 | 
         
            +
             
     | 
| 600 | 
         
            +
                valid_examples = [img for img in example_images if os.path.exists(img)]
         
     | 
| 601 | 
         
            +
             
     | 
| 602 | 
         
            +
                if valid_examples:
         
     | 
| 603 | 
         
            +
                    gr.Markdown("## ๐งฉ Try These Example Challenges:")
         
     | 
| 604 | 
         
            +
                    gr.Examples(
         
     | 
| 605 | 
         
            +
                        examples=valid_examples,
         
     | 
| 606 | 
         
            +
                        inputs=image_input,
         
     | 
| 607 | 
         
            +
                        examples_per_page=4,
         
     | 
| 608 | 
         
            +
                        label=""
         
     | 
| 609 | 
         
            +
                    )
         
     | 
| 610 | 
         
            +
             
     | 
| 611 | 
         
            +
            if __name__ == "__main__":
         
     | 
| 612 | 
         
            +
                app.launch(debug=True)
         
     |