import PIL
import numpy as np
import torch
import torch.nn as nn
import torchvision
from yacs.config import CfgNode as CN
from PIL import ImageDraw, ImageFont
from segment_anything import build_sam, SamPredictor
from segment_anything.utils.amg import remove_small_regions

import groundingdino.util.transforms as T
from constants.constant import DARKER_COLOR_MAP, LIGHTER_COLOR_MAP, COLORS
from groundingdino import build_groundingdino
from groundingdino.util.predict import predict
from groundingdino.util.utils import clean_state_dict

def load_groundingdino_model(model_config_path, model_checkpoint_path):
    with open(model_config_path, "r") as f:
        args = CN.load_cfg(f)
    model = build_groundingdino(args)
    checkpoint = torch.load(model_checkpoint_path, map_location="cpu")
    load_res = model.load_state_dict(clean_state_dict(checkpoint["model"]), strict=False)
    print('loading GroundingDINO:', load_res)
    model.eval()
    return model
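
# A minimal loading sketch (the paths are the ones used below in
# GroundingModule.__init__, not fixed requirements):
#
#   model = load_groundingdino_model(
#       "./eval_configs/GroundingDINO_SwinT_OGC.yaml",
#       "./checkpoints/groundingdino_swint_ogc.pth",
#   )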

class GroundingModule(nn.Module):
    def __init__(self, device='cpu'):
        super().__init__()
        self.device = device
        sam_checkpoint = "./checkpoints/sam_vit_h_4b8939.pth"
        groundingdino_checkpoint = "./checkpoints/groundingdino_swint_ogc.pth"
        groundingdino_config_file = "./eval_configs/GroundingDINO_SwinT_OGC.yaml"
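        # Note: these weights are assumed to have been downloaded separately.
        # sam_vit_h_4b8939.pth is the ViT-H checkpoint released by
        # facebookresearch/segment-anything, and groundingdino_swint_ogc.pth is
        # the Swin-T (OGC) checkpoint from IDEA-Research/GroundingDINO.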
        self.grounding_model = load_groundingdino_model(groundingdino_config_file,
                                                        groundingdino_checkpoint).to(device)
        self.grounding_model.eval()
        sam = build_sam(checkpoint=sam_checkpoint).to(device)
        sam.eval()
        self.sam_predictor = SamPredictor(sam)

    def prompt2mask(self, original_image, prompt, state, box_threshold=0.35, text_threshold=0.25, num_boxes=10):
        def image_transform_grounding(init_image):
            transform = T.Compose([
                T.RandomResize([800], max_size=1333),
                T.ToTensor(),
                T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
            ])
            image, _ = transform(init_image, None)  # 3, h, w
            return init_image, image
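        # The pipeline above mirrors GroundingDINO's DETR-style preprocessing:
        # the shorter side is resized to 800 px (longer side capped at 1333),
        # then the tensor is normalized with the standard ImageNet mean/std.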
        image_np = np.array(original_image, dtype=np.uint8)
        prompt = prompt.lower().strip()
        if not prompt.endswith("."):
            prompt = prompt + "."
        _, image_tensor = image_transform_grounding(original_image)

        print('==> Box grounding with "{}"...'.format(prompt))
        with torch.cuda.amp.autocast(enabled=True):
            boxes, logits, phrases = predict(self.grounding_model,
                                             image_tensor, prompt, box_threshold, text_threshold, device=self.device)
        print(phrases)

        H, W = original_image.size[1], original_image.size[0]
        draw_img = original_image.copy()
        draw = ImageDraw.Draw(draw_img)
        color_boxes = []
        color_masks = []
        local_results = [original_image.copy() for _ in range(len(state['entity']))]
        local2entity = {}
        for obj_ind, (box, label) in enumerate(zip(boxes, phrases)):
            # convert from normalized cxcywh to absolute xyxy pixel coordinates
            box = box * torch.Tensor([W, H, W, H])
            box[:2] -= box[2:] / 2
            box[2:] += box[:2]
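            # e.g. a normalized box (0.5, 0.5, 0.2, 0.4) on a 640x480 image is
            # cxcywh (320, 240, 128, 192) and becomes xyxy (256, 144, 384, 336)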
            # match the predicted phrase against the entities in the dialogue
            # state to reuse their assigned colors (the original duplicated this
            # lookup; the second copy was unreachable and has been removed)
            for i, s in enumerate(state['entity']):
                if label.lower() == s[0].lower():
                    local2entity[obj_ind] = i
                    break
            if obj_ind not in local2entity:
                print('Color not found', label)
                color = (128, 128, 128)  # grey fallback, as an RGB tuple so it also works as a mask color below
            else:
                color = state['entity'][local2entity[obj_ind]][3]
            color_boxes.append(color)
            print(color_boxes)
            # draw the labeled box on the global image
            x0, y0, x1, y1 = box
            x0, y0, x1, y1 = int(x0), int(y0), int(x1), int(y1)
            draw.rectangle([x0, y0, x1, y1], outline=color, width=10)
            # font = ImageFont.load_default()
            font = ImageFont.truetype('InputSans-Regular.ttf', int(H / 512.0 * 30))
            if hasattr(font, "getbbox"):
                bbox = draw.textbbox((x0, y0), str(label), font)
            else:
                w, h = draw.textsize(str(label), font)
                bbox = (x0, y0, w + x0, y0 + h)
            draw.rectangle(bbox, fill=color)
            draw.text((x0, y0), str(label), fill="white", font=font)

            # mirror the annotation on the per-entity image, if an entity matched
            if obj_ind in local2entity:
                local_draw = ImageDraw.Draw(local_results[local2entity[obj_ind]])
                local_draw.rectangle([x0, y0, x1, y1], outline=color, width=10)
                local_draw.rectangle(bbox, fill=color)
                local_draw.text((x0, y0), str(label), fill="white", font=font)

        if boxes.size(0) > 0:
            print('==> Mask grounding...')
            # boxes is still normalized cxcywh here; convert all of them to xyxy
            boxes = boxes * torch.Tensor([W, H, W, H])
            boxes[:, :2] = boxes[:, :2] - boxes[:, 2:] / 2
            boxes[:, 2:] = boxes[:, 2:] + boxes[:, :2]

            self.sam_predictor.set_image(image_np)
            transformed_boxes = self.sam_predictor.transform.apply_boxes_torch(boxes, image_np.shape[:2])
            with torch.cuda.amp.autocast(enabled=True):
                masks, _, _ = self.sam_predictor.predict_torch(
                    point_coords=None,
                    point_labels=None,
                    boxes=transformed_boxes.to(self.device),
                    multimask_output=False,
                )
            # remove small disconnected regions and holes
            fine_masks = []
            for mask in masks.to('cpu').numpy():  # masks: [num_masks, 1, h, w]
                fine_masks.append(remove_small_regions(mask[0], 400, mode="holes")[0])
            masks = np.stack(fine_masks, axis=0)[:, np.newaxis]
            masks = torch.from_numpy(masks)
            num_obj = min(len(logits), num_boxes)
            mask_map = None
            full_img = None
            for obj_ind in range(num_obj):
                # box = boxes[obj_ind]
                m = masks[obj_ind][0]
                if full_img is None:
                    full_img = np.zeros((m.shape[0], m.shape[1], 3))
                    mask_map = np.zeros((m.shape[0], m.shape[1]), dtype=np.uint16)
                local_image = np.zeros((m.shape[0], m.shape[1], 3))
                mask_map[m != 0] = obj_ind + 1
                # color_mask = np.random.random((1, 3)).tolist()[0]
                color_mask = np.array(color_boxes[obj_ind]) / 255.0
                full_img[m != 0] = color_mask
                local_image[m != 0] = color_mask
                # if local_results[local2entity[obj_ind]] is not None:
                #     local_image[m == 0] = np.asarray(local_results[local2entity[obj_ind]])[m == 0]
                local_image = (local_image * 255).astype(np.uint8)
                local_image = PIL.Image.fromarray(local_image)
                # guard against unmatched boxes: the original indexed
                # local2entity[obj_ind] unconditionally, which raised a KeyError
                # whenever a box had no matching entity
                if obj_ind in local2entity and local_results[local2entity[obj_ind]] is not None:
                    local_results[local2entity[obj_ind]] = PIL.Image.blend(local_results[local2entity[obj_ind]],
                                                                           local_image, 0.5)
            full_img = (full_img * 255).astype(np.uint8)
            full_img = PIL.Image.fromarray(full_img)
            draw_img = PIL.Image.blend(draw_img, full_img, 0.5)
        return draw_img, local_results
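
    # prompt2mask returns (draw_img, local_results): the full annotated image
    # plus one per-entity overlay, indexed like state['entity']. From the
    # accesses above, state['entity'] is assumed to be a list of tuples whose
    # first element is the entity name and whose fourth element is its color.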

    # def draw_text(self, entity_state, entity, text):
    #     local_img = entity_state['grounding']['local'][entity]['image'].copy()
    #     H, W = local_img.width, local_img.height
    #     font = ImageFont.truetype('InputSans-Regular.ttf', int(min(H, W) / 512.0 * 30))
    #
    #     for x0, y0 in entity_state['grounding']['local'][entity]['text_positions']:
    #         color = entity_state['grounding']['local'][entity]['color']
    #         local_draw = ImageDraw.Draw(local_img)
    #         if hasattr(font, "getbbox"):
    #             bbox = local_draw.textbbox((x0, y0), str(text), font)
    #         else:
    #             w, h = local_draw.textsize(str(text), font)
    #             bbox = (x0, y0, w + x0, y0 + h)
    #
    #         local_draw.rectangle(bbox, fill=DARKER_COLOR_MAP[color])
    #         local_draw.text((x0, y0), str(text), fill="white", font=font)
    #     return local_img

    def draw(self, original_image, entity_state, item=None):
        original_image = original_image.copy()
        W, H = original_image.width, original_image.height
        font = ImageFont.truetype('InputSans-Regular.ttf', int(min(H, W) / 512.0 * 30))
        local_image = np.zeros((H, W, 3))
        local_mask = np.zeros((H, W), dtype=bool)

        def draw_item(img, item):
            nonlocal local_image, local_mask
            entity = entity_state['match_state'][item]
            ei = entity_state['grounding']['local'][entity]
            color = ei['color']
            local_draw = ImageDraw.Draw(img)
            for x0, y0, x1, y1 in ei['entity_positions']:
                local_draw.rectangle([x0, y0, x1, y1], outline=DARKER_COLOR_MAP[color],
                                     width=int(min(H, W) / 512.0 * 10))
            for x0, y0 in ei['text_positions']:
                if hasattr(font, "getbbox"):
                    bbox = local_draw.textbbox((x0, y0), str(item), font)
                else:
                    w, h = local_draw.textsize(str(item), font)
                    bbox = (x0, y0, w + x0, y0 + h)
                local_draw.rectangle(bbox, fill=DARKER_COLOR_MAP[color])
                local_draw.text((x0, y0), str(item), fill="white", font=font)
            for m in ei['masks']:
                local_image[m != 0] = np.array(LIGHTER_COLOR_MAP[color]) / 255.0
                local_mask = np.logical_or(local_mask, m)
            # local_image = (local_image * 255).astype(np.uint8)
            # local_image = PIL.Image.fromarray(local_image)
            # img = PIL.Image.blend(img, local_image, 0.5)
            return img

        if item is None:
            for item in entity_state['match_state'].keys():
                original_image = draw_item(original_image, item)
        else:
            original_image = draw_item(original_image, item)

        # blend the colored masks over the annotated image; pixels outside
        # every mask stay identical to the original
        local_image[local_mask == 0] = (np.array(original_image) / 255.0)[local_mask == 0]
        local_image = (local_image * 255).astype(np.uint8)
        local_image = PIL.Image.fromarray(local_image)
        original_image = PIL.Image.blend(original_image, local_image, 0.5)
        return original_image
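
    # Expected shape of `entity_state`, as inferred from the accesses above
    # (a sketch, not a guaranteed schema):
    #
    #   entity_state = {
    #       'match_state': {item: entity, ...},
    #       'grounding': {'local': {entity: {
    #           'color': <color name>,
    #           'text_positions': [(x0, y0), ...],
    #           'entity_positions': [(x0, y0, x1, y1), ...],
    #           'masks': [<h-by-w boolean mask>, ...],
    #       }}},
    #   }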

    def prompt2mask2(self, original_image, prompt, state, box_threshold=0.25,
                     text_threshold=0.2, iou_threshold=0.5, num_boxes=10):
        def image_transform_grounding(init_image):
            transform = T.Compose([
                T.RandomResize([800], max_size=1333),
                T.ToTensor(),
                T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
            ])
            image, _ = transform(init_image, None)  # 3, h, w
            return init_image, image

        image_np = np.array(original_image, dtype=np.uint8)
        prompt = prompt.lower().strip()
        if not prompt.endswith("."):
            prompt = prompt + "."
        _, image_tensor = image_transform_grounding(original_image)

        print('==> Box grounding with "{}"...'.format(prompt))
        with torch.cuda.amp.autocast(enabled=True):
            boxes, logits, phrases = predict(self.grounding_model,
                                             image_tensor, prompt, box_threshold, text_threshold, device=self.device)
        print('==> Box grounding results {}...'.format(phrases))

        # boxes_filt = boxes.cpu()
        # # use NMS to handle overlapped boxes
        # print(f"==> Before NMS: {boxes_filt.shape[0]} boxes")
        # nms_idx = torchvision.ops.nms(boxes_filt, logits, iou_threshold).numpy().tolist()
        # boxes_filt = boxes_filt[nms_idx]
        # phrases = [phrases[idx] for idx in nms_idx]
        # print(f"==> After NMS: {boxes_filt.shape[0]} boxes")
        # boxes = boxes_filt
        H, W = original_image.size[1], original_image.size[0]
        draw_img = original_image.copy()
        draw = ImageDraw.Draw(draw_img)
        color_boxes = []
        color_masks = []
        entity_dict = {}
        for obj_ind, (box, label) in enumerate(zip(boxes, phrases)):
            # convert from normalized cxcywh to absolute xyxy pixel coordinates
            box = box * torch.Tensor([W, H, W, H])
            box[:2] -= box[2:] / 2
            box[2:] += box[:2]

            # assign each distinct phrase a color and record its annotations
            if label not in entity_dict:
                entity_dict[label] = {
                    'color': COLORS[len(entity_dict) % (len(COLORS) - 1)],
                    # 'image': original_image.copy(),
                    'text_positions': [],
                    'entity_positions': [],
                    'masks': []
                }
            color = entity_dict[label]['color']
            color_boxes.append(DARKER_COLOR_MAP[color])
            color_masks.append(LIGHTER_COLOR_MAP[color])

            # draw the labeled box
            x0, y0, x1, y1 = box
            x0, y0, x1, y1 = int(x0), int(y0), int(x1), int(y1)
            draw.rectangle([x0, y0, x1, y1], outline=DARKER_COLOR_MAP[color], width=10)
            font = ImageFont.truetype('InputSans-Regular.ttf', int(min(H, W) / 512.0 * 30))
            if hasattr(font, "getbbox"):
                bbox = draw.textbbox((x0, y0), str(label), font)
            else:
                w, h = draw.textsize(str(label), font)
                bbox = (x0, y0, w + x0, y0 + h)
            draw.rectangle(bbox, fill=DARKER_COLOR_MAP[color])
            draw.text((x0, y0), str(label), fill="white", font=font)
            # local_img = entity_dict[label]['image']
            # local_draw = ImageDraw.Draw(local_img)
            # local_draw.rectangle([x0, y0, x1, y1], outline=DARKER_COLOR_MAP[color], width=10)
            entity_dict[label]['text_positions'].append((x0, y0))
            entity_dict[label]['entity_positions'].append((x0, y0, x1, y1))
            # local_draw.rectangle(bbox, fill=DARKER_COLOR_MAP[color])
            # local_draw.text((x0, y0), str(label), fill="white", font=font)

        if boxes.size(0) > 0:
            print('==> Mask grounding...')
            # boxes is still normalized cxcywh here; convert all of them to xyxy
            boxes = boxes * torch.Tensor([W, H, W, H])
            boxes[:, :2] = boxes[:, :2] - boxes[:, 2:] / 2
            boxes[:, 2:] = boxes[:, 2:] + boxes[:, :2]

            self.sam_predictor.set_image(image_np)
            transformed_boxes = self.sam_predictor.transform.apply_boxes_torch(boxes,
                                                                               image_np.shape[:2]).to(self.device)
            with torch.cuda.amp.autocast(enabled=True):
                masks, _, _ = self.sam_predictor.predict_torch(
                    point_coords=None,
                    point_labels=None,
                    boxes=transformed_boxes,
                    multimask_output=False,
                )
            # remove small disconnected regions and holes
            fine_masks = []
            for mask in masks.to('cpu').numpy():  # masks: [num_masks, 1, h, w]
                fine_masks.append(remove_small_regions(mask[0], 400, mode="holes")[0])
            masks = np.stack(fine_masks, axis=0)[:, np.newaxis]
            masks = torch.from_numpy(masks)

            mask_map = None
            full_img = None
            for obj_ind, (box, label) in enumerate(zip(boxes, phrases)):
                m = masks[obj_ind][0]
                if full_img is None:
                    full_img = np.zeros((m.shape[0], m.shape[1], 3))
                    mask_map = np.zeros((m.shape[0], m.shape[1]), dtype=np.uint16)
                mask_map[m != 0] = obj_ind + 1
                color_mask = np.array(color_masks[obj_ind]) / 255.0
                full_img[m != 0] = color_mask
                entity_dict[label]['masks'].append(m)
                # local_image = np.zeros((m.shape[0], m.shape[1], 3))
                # local_image[m != 0] = color_mask
                # local_image[m == 0] = (np.array(entity_dict[label]['image']) / 255.0)[m == 0]
                # local_image = (local_image * 255).astype(np.uint8)
                # local_image = PIL.Image.fromarray(local_image)
                # entity_dict[label]['image'] = PIL.Image.blend(entity_dict[label]['image'], local_image, 0.5)
            full_img = (full_img * 255).astype(np.uint8)
            full_img = PIL.Image.fromarray(full_img)
            draw_img = PIL.Image.blend(draw_img, full_img, 0.5)

        print('==> Entity list: {}'.format(list(entity_dict.keys())))
        return draw_img, entity_dict
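
# A minimal usage sketch (assumptions: the checkpoints listed in __init__ are
# present, InputSans-Regular.ttf is on the working path since the drawing code
# loads it by relative name, and 'my_image.jpg' is a placeholder):
#
#   from PIL import Image
#   grounding = GroundingModule(device='cuda')
#   image = Image.open('my_image.jpg').convert('RGB')
#   draw_img, entity_dict = grounding.prompt2mask2(image, 'a dog. a cat', state=None)
#   draw_img.save('grounded.png')
#
# Note that prompt2mask2 never reads its `state` argument, so None is safe there;
# prompt2mask, by contrast, requires a populated state['entity'] list.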