import json

import cv2
import gradio as gr
import numpy as np
import torch
from ultralytics import YOLO
# Custom CSS for a more modern UI inspired by NextUI
custom_css = """
:root {
    --primary: #0070f3;
    --primary-foreground: #ffffff;
    --background: #f5f5f5;
    --card: #ffffff;
    --card-foreground: #111111;
    --border: #eaeaea;
    --ring: #0070f3;
    --radius: 0.5rem;
    --shadow: 0 4px 14px 0 rgba(0, 118, 255, 0.1);
}
.dark {
    --primary: #0070f3;
    --primary-foreground: #ffffff;
    --background: #000000;
    --card: #111111;
    --card-foreground: #ffffff;
    --border: #333333;
    --ring: #0070f3;
}
.gradio-container {
    margin: 0 !important;
    padding: 0 !important;
    max-width: 100% !important;
}
.main-container {
    background-color: var(--background);
    padding: 2rem;
}
.header {
    margin-bottom: 2rem;
    text-align: center;
}
.header h1 {
    font-size: 2.5rem;
    font-weight: 800;
    color: var(--card-foreground);
    margin-bottom: 0.5rem;
    background: linear-gradient(to right, #0070f3, #00bfff);
    -webkit-background-clip: text;
    -webkit-text-fill-color: transparent;
}
.header p {
    color: var(--card-foreground);
    opacity: 0.8;
    font-size: 1.1rem;
}
.tab-nav {
    background-color: var(--card);
    border-radius: var(--radius);
    padding: 0.5rem;
    margin-bottom: 2rem;
    box-shadow: var(--shadow);
}
.tab-nav button {
    border-radius: var(--radius) !important;
    font-weight: 600 !important;
    transition: all 0.2s ease-in-out !important;
    padding: 0.75rem 1.5rem !important;
}
.tab-nav button.selected {
    background-color: var(--primary) !important;
    color: var(--primary-foreground) !important;
    transform: translateY(-2px);
    box-shadow: 0 4px 14px 0 rgba(0, 118, 255, 0.25);
}
.input-panel, .output-panel {
    background-color: var(--card);
    border-radius: var(--radius);
    padding: 1.5rem;
    box-shadow: var(--shadow);
    height: 100%;
    display: flex;
    flex-direction: column;
}
.input-panel h3, .output-panel h3 {
    font-size: 1.25rem;
    font-weight: 600;
    margin-bottom: 1rem;
    color: var(--card-foreground);
    border-bottom: 2px solid var(--primary);
    padding-bottom: 0.5rem;
    display: inline-block;
}
.gr-button-primary {
    background-color: var(--primary) !important;
    color: var(--primary-foreground) !important;
    border-radius: var(--radius) !important;
    font-weight: 600 !important;
    transition: all 0.2s ease-in-out !important;
    padding: 0.75rem 1.5rem !important;
    box-shadow: 0 4px 14px 0 rgba(0, 118, 255, 0.25) !important;
    width: 100%;
    margin-top: 1rem;
}
.gr-button-primary:hover {
    transform: translateY(-2px) !important;
    box-shadow: 0 6px 20px rgba(0, 118, 255, 0.35) !important;
}
.gr-form {
    border: none !important;
    background: transparent !important;
}
.gr-input, .gr-select {
    border: 1px solid var(--border) !important;
    border-radius: var(--radius) !important;
    padding: 0.75rem 1rem !important;
    transition: all 0.2s ease-in-out !important;
}
.gr-input:focus, .gr-select:focus {
    border-color: var(--primary) !important;
    box-shadow: 0 0 0 2px rgba(0, 118, 255, 0.25) !important;
}
.gr-panel {
    border: none !important;
}
.gr-accordion {
    border: 1px solid var(--border) !important;
    border-radius: var(--radius) !important;
    overflow: hidden;
}
.footer {
    margin-top: 2rem;
    border-top: 1px solid var(--border);
    padding-top: 1.5rem;
    font-size: 0.9rem;
    color: var(--card-foreground);
    opacity: 0.7;
    text-align: center;
}
.footer-card {
    background-color: var(--card);
    border-radius: var(--radius);
    padding: 1.5rem;
    box-shadow: var(--shadow);
}
.tips-grid {
    display: grid;
    grid-template-columns: repeat(auto-fit, minmax(200px, 1fr));
    gap: 1rem;
    margin-top: 1rem;
}
.tip-card {
    background-color: var(--card);
    border-radius: var(--radius);
    padding: 1rem;
    border-left: 3px solid var(--primary);
}
"""
# Available model sizes
DETECTION_MODELS = {
    "small": "yolov8s-worldv2.pt",
    "medium": "yolov8m-worldv2.pt",
    "large": "yolov8l-worldv2.pt",
    "xlarge": "yolov8x-worldv2.pt",
}
SEGMENTATION_MODELS = {
    "YOLOv8 Nano": "yolov8n-seg.pt",
    "YOLOv8 Small": "yolov8s-seg.pt",
    "YOLOv8 Medium": "yolov8m-seg.pt",
    "YOLOv8 Large": "yolov8l-seg.pt",
}
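# Note: Ultralytics downloads the weight files above automatically on first
# use and caches them locally, so the first request with a new model size may
# be noticeably slower (this assumes the deployment has network access).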
class YOLOWorldDetector:
    def __init__(self, model_size="small"):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model_size = model_size
        self.model_name = DETECTION_MODELS[model_size]
        self._load_detection_model()
        # Cache of loaded segmentation models, keyed by display name
        self.seg_models = {}

    def _load_detection_model(self):
        print(f"Loading {self.model_name} on {self.device}...")
        try:
            # Prefer the open-vocabulary YOLOWorld model
            from ultralytics import YOLOWorld
            self.model = YOLOWorld(self.model_name)
            self.model_type = "yoloworld"
            print("YOLOWorld model loaded successfully!")
        except Exception as e:
            print(f"Error loading YOLOWorld model: {e}")
            print("Falling back to standard YOLOv8 for detection...")
            self.model = YOLO("yolov8n.pt")
            self.model_type = "yolov8"
            print("YOLOv8 fallback model loaded successfully!")

    def change_model(self, model_size):
        if model_size != self.model_size:
            self.model_size = model_size
            self.model_name = DETECTION_MODELS[model_size]
            self._load_detection_model()
        return f"Using {self.model_name} model"

    def load_seg_model(self, model_name):
        if model_name not in self.seg_models:
            print(f"Loading segmentation model {model_name}...")
            self.seg_models[model_name] = YOLO(SEGMENTATION_MODELS[model_name])
            print(f"Segmentation model {model_name} loaded successfully!")
        return self.seg_models[model_name]
    def detect(self, image, text_prompt, confidence_threshold=0.3):
        if image is None:
            return None, "No image provided"
        # Normalize the input to a numpy array so we can read its dimensions
        if isinstance(image, str):
            img_for_json = cv2.imread(image)
        elif isinstance(image, np.ndarray):
            img_for_json = image.copy()
        else:
            # Convert PIL Image to numpy array if needed
            img_for_json = np.array(image)
        # Run inference based on model type
        if self.model_type == "yoloworld":
            try:
                # Parse the comma-separated text prompt into class names
                if text_prompt and text_prompt.strip():
                    classes = [cls.strip() for cls in text_prompt.split(',') if cls.strip()]
                else:
                    classes = None
                # Only restrict the vocabulary when class names were provided
                if classes:
                    self.model.set_classes(classes)
                results = self.model.predict(
                    source=image,
                    conf=confidence_threshold,
                    verbose=False
                )
            except Exception as e:
                print(f"Error during YOLOWorld inference: {e}")
                print("Falling back to standard YOLO inference...")
                # If YOLOWorld inference fails, run without text prompts
                results = self.model.predict(
                    source=image,
                    conf=confidence_threshold,
                    verbose=False
                )
        else:
            # Standard YOLO doesn't use text prompts
            results = self.model.predict(
                source=image,
                conf=confidence_threshold,
                verbose=False
            )
        # Annotated image with boxes and labels drawn by Ultralytics
        res_plotted = results[0].plot()
        # Convert results to JSON format (percentage coordinates)
        json_results = []
        img_height, img_width = img_for_json.shape[:2]
        for box, cls, conf in zip(
            results[0].boxes.xyxy.cpu().numpy(),
            results[0].boxes.cls.cpu().numpy(),
            results[0].boxes.conf.cpu().numpy()
        ):
            x1, y1, x2, y2 = box
            json_results.append({
                "bbox": {
                    "x": float((x1 / img_width) * 100),
                    "y": float((y1 / img_height) * 100),
                    "width": float(((x2 - x1) / img_width) * 100),
                    "height": float(((y2 - y1) / img_height) * 100)
                },
                "score": float(conf),
                "label": int(cls),
                "label_text": results[0].names[int(cls)]
            })
        return res_plotted, json_results
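    # A single detect() entry looks like this (illustrative values):
    # {
    #     "bbox": {"x": 12.5, "y": 8.3, "width": 30.1, "height": 45.7},
    #     "score": 0.87,
    #     "label": 0,
    #     "label_text": "person",
    # }
    # All coordinates are percentages of the image size, so they remain valid
    # at any rendered resolution.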
    def segment(self, image, model_name, confidence_threshold=0.3):
        if image is None:
            return None, "No image provided"
        # Load segmentation model if not already cached
        model = self.load_seg_model(model_name)
        # Run inference
        results = model(image, conf=confidence_threshold, verbose=False)
        # Annotated image with masks drawn by Ultralytics
        res_plotted = results[0].plot()
        # Convert results to JSON format (percentage coordinates)
        json_results = []
        if results[0].masks is not None:
            img_height, img_width = results[0].orig_shape
            for box, mask, cls, conf in zip(
                results[0].boxes.xyxy.cpu().numpy(),
                results[0].masks.data.cpu().numpy(),
                results[0].boxes.cls.cpu().numpy(),
                results[0].boxes.conf.cpu().numpy()
            ):
                x1, y1, x2, y2 = box
                # Masks come back at the model's inference resolution, which
                # may differ from the original image, so normalize polygon
                # points by the mask's own dimensions rather than orig_shape
                mask_height, mask_width = mask.shape[:2]
                # Convert mask to polygon for SVG-like representation.
                # Simplified approach - in production you might want a more
                # sophisticated polygon extraction.
                contours, _ = cv2.findContours((mask > 0.5).astype(np.uint8),
                                               cv2.RETR_EXTERNAL,
                                               cv2.CHAIN_APPROX_SIMPLE)
                if contours:
                    # Keep the largest contour and simplify it
                    largest_contour = max(contours, key=cv2.contourArea)
                    epsilon = 0.005 * cv2.arcLength(largest_contour, True)
                    approx = cv2.approxPolyDP(largest_contour, epsilon, True)
                    # Convert to percentage coordinates
                    points = []
                    for point in approx:
                        x, y = point[0]
                        points.append({
                            "x": float((x / mask_width) * 100),
                            "y": float((y / mask_height) * 100)
                        })
                    json_results.append({
                        "bbox": {
                            "x": float((x1 / img_width) * 100),
                            "y": float((y1 / img_height) * 100),
                            "width": float(((x2 - x1) / img_width) * 100),
                            "height": float(((y2 - y1) / img_height) * 100)
                        },
                        "score": float(conf),
                        "label": int(cls),
                        "label_text": results[0].names[int(cls)],
                        "polygon": points
                    })
        return res_plotted, json_results
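    # cv2.approxPolyDP's epsilon above (0.5% of the contour perimeter)
    # controls how aggressively each outline is simplified: smaller factors
    # keep more vertices and detail, larger ones yield lighter SVG polygons.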
# Initialize detector with default model
detector = YOLOWorldDetector(model_size="small")
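# Note: this single module-level detector is shared by all requests, so
# change_model() switches the active model for every connected session.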
def create_svg_from_detections(json_results, img_width, img_height):
    """Convert detection results to SVG markup."""
    svg_header = f'<svg width="{img_width}" height="{img_height}" xmlns="http://www.w3.org/2000/svg">'
    svg_content = ""
    # Color palette cycled across detections
    colors = [
        "#FF3B30", "#FF9500", "#FFCC00", "#4CD964",
        "#5AC8FA", "#007AFF", "#5856D6", "#FF2D55"
    ]
    for i, result in enumerate(json_results):
        bbox = result["bbox"]
        label = result.get("label_text", f"Object {i}")
        score = result.get("score", 0)
        # Convert percentages back to absolute pixel coordinates
        x = (bbox["x"] / 100) * img_width
        y = (bbox["y"] / 100) * img_height
        width = (bbox["width"] / 100) * img_width
        height = (bbox["height"] / 100) * img_height
        # Select a color by detection index
        color = colors[i % len(colors)]
        # Rectangle plus a label above its top-left corner
        svg_content += f'''
    <rect
        x="{x:.2f}"
        y="{y:.2f}"
        width="{width:.2f}"
        height="{height:.2f}"
        stroke="{color}"
        stroke-width="2"
        fill="none"
        data-label="{label}"
        data-score="{score:.2f}"
    />
    <text
        x="{x:.2f}"
        y="{y-5:.2f}"
        font-family="Arial"
        font-size="12"
        fill="{color}"
    >{label} ({score:.2f})</text>'''
    svg_footer = "\n</svg>"
    return svg_header + svg_content + svg_footer
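# The generated markup looks roughly like this (illustrative values):
# <svg width="640" height="480" xmlns="http://www.w3.org/2000/svg">
#     <rect x="80.00" y="53.14" width="192.64" height="293.42"
#           stroke="#FF3B30" stroke-width="2" fill="none"
#           data-label="person" data-score="0.87" />
#     <text x="80.00" y="48.14" font-family="Arial" font-size="12"
#           fill="#FF3B30">person (0.87)</text>
# </svg>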
def create_svg_from_segmentation(json_results, img_width, img_height):
    """Convert segmentation results to SVG markup."""
    svg_header = f'<svg width="{img_width}" height="{img_height}" xmlns="http://www.w3.org/2000/svg">'
    svg_content = ""
    # Color palette cycled across detections
    colors = [
        "#FF3B30", "#FF9500", "#FFCC00", "#4CD964",
        "#5AC8FA", "#007AFF", "#5856D6", "#FF2D55"
    ]
    for i, result in enumerate(json_results):
        label = result.get("label_text", f"Object {i}")
        score = result.get("score", 0)
        # Select a color by detection index
        color = colors[i % len(colors)]
        # Draw the mask polygon if one was extracted
        if "polygon" in result:
            points_str = " ".join(
                f"{(p['x'] / 100) * img_width:.2f},{(p['y'] / 100) * img_height:.2f}"
                for p in result["polygon"]
            )
            svg_content += f'''
    <polygon
        points="{points_str}"
        stroke="{color}"
        stroke-width="2"
        fill="{color}33"
        data-label="{label}"
        data-score="{score:.2f}"
    />'''
        # Also draw a dashed bounding box with a label
        bbox = result["bbox"]
        x = (bbox["x"] / 100) * img_width
        y = (bbox["y"] / 100) * img_height
        width = (bbox["width"] / 100) * img_width
        height = (bbox["height"] / 100) * img_height
        svg_content += f'''
    <rect
        x="{x:.2f}"
        y="{y:.2f}"
        width="{width:.2f}"
        height="{height:.2f}"
        stroke="{color}"
        stroke-width="1"
        fill="none"
        stroke-dasharray="5,5"
    />
    <text
        x="{x:.2f}"
        y="{y-5:.2f}"
        font-family="Arial"
        font-size="12"
        fill="{color}"
    >{label} ({score:.2f})</text>'''
    svg_footer = "\n</svg>"
    return svg_header + svg_content + svg_footer
def detection_inference(image, text_prompt, confidence, model_size):
    # Switch models if the user picked a different size
    detector.change_model(model_size)
    # Run detection
    result_image, json_results = detector.detect(
        image,
        text_prompt,
        confidence_threshold=confidence
    )
    # Build the SVG overlay and serialize the results as real JSON
    if isinstance(json_results, list) and len(json_results) > 0:
        img_height, img_width = result_image.shape[:2]
        svg_output = create_svg_from_detections(json_results, img_width, img_height)
        json_text = json.dumps(json_results, indent=2)
    else:
        svg_output = "<svg></svg>"
        json_text = str(json_results)
    return result_image, json_text, svg_output
def segmentation_inference(image, confidence, model_name):
    # Run segmentation
    result_image, json_results = detector.segment(
        image,
        model_name,
        confidence_threshold=confidence
    )
    # Build the SVG overlay and serialize the results as real JSON
    if isinstance(json_results, list) and len(json_results) > 0:
        img_height, img_width = result_image.shape[:2]
        svg_output = create_svg_from_segmentation(json_results, img_width, img_height)
        json_text = json.dumps(json_results, indent=2)
    else:
        svg_output = "<svg></svg>"
        json_text = str(json_results)
    return result_image, json_text, svg_output
# Create Gradio interface
with gr.Blocks(title="YOLO Vision Suite", css=custom_css) as demo:
    with gr.Column(elem_classes="main-container"):
        with gr.Column(elem_classes="header"):
            gr.Markdown("# YOLO Vision Suite")
            gr.Markdown("Advanced object detection and segmentation powered by YOLO models")
        with gr.Tabs(elem_classes="tab-nav") as tabs:
            with gr.TabItem("Object Detection", elem_id="detection-tab"):
                with gr.Row(equal_height=True):
                    with gr.Column(elem_classes="input-panel", scale=1):
                        gr.Markdown("### Input")
                        input_image = gr.Image(label="Upload Image", type="numpy", height=300)
                        text_prompt = gr.Textbox(
                            label="Text Prompt",
                            placeholder="person, car, dog",
                            value="person, car, dog",
                            elem_classes="gr-input"
                        )
                        with gr.Row():
                            confidence = gr.Slider(
                                minimum=0.1,
                                maximum=1.0,
                                value=0.3,
                                step=0.05,
                                label="Confidence Threshold"
                            )
                            model_dropdown = gr.Dropdown(
                                choices=list(DETECTION_MODELS.keys()),
                                value="small",
                                label="Model Size",
                                elem_classes="gr-select"
                            )
                        detect_button = gr.Button("Detect Objects", elem_classes="gr-button-primary")
                    with gr.Column(elem_classes="output-panel", scale=1):
                        gr.Markdown("### Results")
                        output_image = gr.Image(label="Detection Result", height=300)
                        with gr.Accordion("SVG Output", open=False, elem_classes="gr-accordion"):
                            svg_output = gr.HTML(label="SVG Visualization")
                        with gr.Accordion("JSON Output", open=False, elem_classes="gr-accordion"):
                            json_output = gr.Textbox(
                                label="Bounding Box Data (Percentage Coordinates)",
                                elem_classes="gr-input",
                                lines=5
                            )
            with gr.TabItem("Segmentation", elem_id="segmentation-tab"):
                with gr.Row(equal_height=True):
                    with gr.Column(elem_classes="input-panel", scale=1):
                        gr.Markdown("### Input")
                        seg_input_image = gr.Image(label="Upload Image", type="numpy", height=300)
                        with gr.Row():
                            seg_confidence = gr.Slider(
                                minimum=0.1,
                                maximum=1.0,
                                value=0.3,
                                step=0.05,
                                label="Confidence Threshold"
                            )
                            seg_model_dropdown = gr.Dropdown(
                                choices=list(SEGMENTATION_MODELS.keys()),
                                value="YOLOv8 Small",
                                label="Model Size",
                                elem_classes="gr-select"
                            )
                        segment_button = gr.Button("Segment Image", elem_classes="gr-button-primary")
                    with gr.Column(elem_classes="output-panel", scale=1):
                        gr.Markdown("### Results")
                        seg_output_image = gr.Image(label="Segmentation Result", height=300)
                        with gr.Accordion("SVG Output", open=False, elem_classes="gr-accordion"):
                            seg_svg_output = gr.HTML(label="SVG Visualization")
                        with gr.Accordion("JSON Output", open=False, elem_classes="gr-accordion"):
                            seg_json_output = gr.Textbox(
                                label="Segmentation Data (Percentage Coordinates)",
                                elem_classes="gr-input",
                                lines=5
                            )
        with gr.Column(elem_classes="footer"):
            with gr.Column(elem_classes="footer-card"):
                gr.Markdown("### Tips & Information")
                with gr.Row(elem_classes="tips-grid"):
                    with gr.Column(elem_classes="tip-card"):
                        gr.Markdown("**Detection**")
                        gr.Markdown("Enter comma-separated text prompts to specify which objects to detect")
                    with gr.Column(elem_classes="tip-card"):
                        gr.Markdown("**Segmentation**")
                        gr.Markdown("The model will identify and segment common objects automatically")
                    with gr.Column(elem_classes="tip-card"):
                        gr.Markdown("**Models**")
                        gr.Markdown("Larger models provide better accuracy but require more processing power")
                    with gr.Column(elem_classes="tip-card"):
                        gr.Markdown("**Output**")
                        gr.Markdown("JSON output provides coordinates as percentages, compatible with SVG")
    # Set up event handlers (inside the Blocks context)
    detect_button.click(
        detection_inference,
        inputs=[input_image, text_prompt, confidence, model_dropdown],
        outputs=[output_image, json_output, svg_output]
    )
    segment_button.click(
        segmentation_inference,
        inputs=[seg_input_image, seg_confidence, seg_model_dropdown],
        outputs=[seg_output_image, seg_json_output, seg_svg_output]
    )
if __name__ == "__main__":
    demo.launch(share=True)  # Set share=True to create a public link