import gradio as gr
import torch
#import spaces
import json
import base64
from io import BytesIO
from transformers import SamHQModel, SamHQProcessor, SamModel, SamProcessor
import os
import pandas as pd
from utils import *
from PIL import Image
from gradio_image_prompter import ImagePrompter

# Load SAM-HQ and the original SAM (ViT-Huge checkpoints) for side-by-side comparison.
sam_hq_model = SamHQModel.from_pretrained("syscv-community/sam-hq-vit-huge")
sam_hq_processor = SamHQProcessor.from_pretrained("syscv-community/sam-hq-vit-huge")
sam_model = SamModel.from_pretrained("facebook/sam-vit-huge")
sam_processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")


#@spaces.GPU
def predict_masks_and_scores(model, processor, raw_image, input_points=None, input_boxes=None):
    """Run a SAM-style model on one image and return post-processed masks with IoU scores."""
    if input_boxes is not None:
        input_boxes = [input_boxes]
    inputs = processor(raw_image, input_boxes=input_boxes, input_points=input_points, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    # Upscale the low-resolution predicted masks back to the original image size.
    masks = processor.image_processor.post_process_masks(
        outputs.pred_masks.cpu(),
        inputs["original_sizes"].cpu(),
        inputs["reshaped_input_sizes"].cpu(),
    )
    scores = outputs.iou_scores
    return masks, scores


def process_inputs(prompts):
    # Each ImagePrompter entry is (x1, y1, type, x2, y2, class):
    # type 1 is a clicked point, type 2 is a dragged bounding box.
    raw_entries = prompts["points"]
    input_points = []
    input_boxes = []
    for entry in raw_entries:
        x1, y1, type_, x2, y2, cls = entry
        if type_ == 1:
            input_points.append([int(x1), int(y1)])
        elif type_ == 2:
            x_min = int(min(x1, x2))
            y_min = int(min(y1, y2))
            x_max = int(max(x1, x2))
            y_max = int(max(y1, y2))
            input_boxes.append([x_min, y_min, x_max, y_max])

    # Wrap prompts in a batch dimension, or pass None when the user gave none of that kind.
    input_boxes = [input_boxes] if input_boxes else None
    input_points = [input_points] if input_points else None

    user_image = prompts["image"]
    sam_masks, sam_scores = predict_masks_and_scores(
        sam_model, sam_processor, user_image, input_boxes=input_boxes, input_points=input_points
    )
    sam_hq_masks, sam_hq_scores = predict_masks_and_scores(
        sam_hq_model, sam_hq_processor, user_image, input_boxes=input_boxes, input_points=input_points
    )

    # Overlay each model's masks on the image. At least one prompt type is assumed;
    # with no prompts at all, img1_b64/img2_b64 would be unbound below.
    if input_boxes and input_points:
        img1_b64 = show_all_annotations_on_image_base64(user_image, sam_masks[0][0], sam_scores[:, 0, :], input_boxes[0], input_points[0], model_name='SAM')
        img2_b64 = show_all_annotations_on_image_base64(user_image, sam_hq_masks[0][0], sam_hq_scores[:, 0, :], input_boxes[0], input_points[0], model_name='SAM_HQ')
    elif input_boxes:
        img1_b64 = show_all_annotations_on_image_base64(user_image, sam_masks[0][0], sam_scores[:, 0, :], input_boxes[0], None, model_name='SAM')
        img2_b64 = show_all_annotations_on_image_base64(user_image, sam_hq_masks[0][0], sam_hq_scores[:, 0, :], input_boxes[0], None, model_name='SAM_HQ')
    elif input_points:
        img1_b64 = show_all_annotations_on_image_base64(user_image, sam_masks[0][0], sam_scores[:, 0, :], None, input_points[0], model_name='SAM')
        img2_b64 = show_all_annotations_on_image_base64(user_image, sam_hq_masks[0][0], sam_hq_scores[:, 0, :], None, input_points[0], model_name='SAM_HQ')

    print('user_image', user_image)
    print("img1_b64", img1_b64)
    print("img2_b64", img2_b64)

    html_code = f"""