# --------------------------------------------------------
# X-Decoder -- Generalized Decoding for Pixel, Image, and Language
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Xueyan Zou ([email protected])
# --------------------------------------------------------
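#
# A sketch of the expected invocation, mirroring the repository's other demo
# scripts (the config file and checkpoint names below are unverified
# placeholders, not taken from this file):
#   python demo/demo_region_retrieval.py evaluate \
#       --conf_files configs/xdecoder/xdecoder_focalt_lang.yaml \
#       --overrides RESUME_FROM /path/to/xdecoder_checkpoint.pt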
import os
import sys
import logging
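
# Make the repository root importable (this script lives one directory below it)
# so that the local `utils` and `modeling` packages resolve without installation.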
pth = '/'.join(sys.path[0].split('/')[:-1])
sys.path.insert(0, pth)
from PIL import Image
import numpy as np
np.random.seed(0)
import cv2
import torch
from torchvision import transforms
from utils.arguments import load_opt_command
from detectron2.data import MetadataCatalog
from modeling.language.loss import vl_similarity
from modeling.BaseModel import BaseModel
from modeling import build_model
from utils.visualizer import Visualizer
from utils.distributed import init_distributed
logger = logging.getLogger(__name__)


def main(args=None):
    '''
    Region retrieval demo: score a set of candidate images against a text
    query, then ground the query phrase inside the best-matching image.
    '''
    opt, cmdline_args = load_opt_command(args)
    if cmdline_args.user_dir:
        absolute_user_dir = os.path.abspath(cmdline_args.user_dir)
        opt['base_path'] = absolute_user_dir

    opt = init_distributed(opt)
    # META DATA
    pretrained_pth = opt['RESUME_FROM']
    output_root = './output'
    image_list = ['inference/images/coco/000.jpg', 'inference/images/coco/001.jpg', 'inference/images/coco/002.jpg', 'inference/images/coco/003.jpg']
    text = ['pizza on the plate']

    model = BaseModel(opt, build_model(opt)).from_pretrained(pretrained_pth).eval().cuda()
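
    # Prime the language encoder with placeholder 'background' text embeddings
    # before running inference.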
    model.model.sem_seg_head.predictor.lang_encoder.get_text_embeddings(["background", "background"], is_eval=False)
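
    # Two input scales: 224 px for the retrieval scoring pass, 512 px for the
    # finer-grained grounding pass on the selected image.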
    transform_ret = transforms.Compose([transforms.Resize(224, interpolation=Image.BICUBIC)])
    transform_grd = transforms.Compose([transforms.Resize(512, interpolation=Image.BICUBIC)])
    metadata = MetadataCatalog.get('ade20k_panoptic_train')

    color = [0/255, 255/255, 0/255]  # mask overlay color (green, RGB floats in [0, 1])
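
    # Stage 1 (retrieval): embed every candidate image and the query text in the
    # joint vision-language space, then pick the candidate that best matches.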
    with torch.no_grad():
        batch_inputs = []
        candidate_list = []
        for j in range(len(image_list)):
            image_ori = Image.open(image_list[j])
            width = image_ori.size[0]
            height = image_ori.size[1]

            image = transform_ret(image_ori)
            image = np.asarray(image)
            candidate_list += [image]
            image_ori = np.asarray(image_ori)
            images = torch.from_numpy(image.copy()).permute(2,0,1).cuda()
            batch_inputs += [{'image': images, 'height': height, 'width': width}]
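
        # One forward pass over all candidates; each output carries caption
        # embeddings in the shared vision-language space. Keep the last caption
        # embedding per image (used here as the image-level descriptor) and
        # L2-normalize it.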
        outputs = model.model.evaluate(batch_inputs)
        v_emb = torch.cat([x['captions'][-1:] for x in outputs])
        v_emb = v_emb / (v_emb.norm(dim=-1, keepdim=True) + 1e-7)
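
        # Encode the query text; the encoder caches the result on itself as
        # `caption_text_embeddings`, which is fetched back via getattr. The
        # candidates are then ranked by temperature-scaled cosine similarity,
        # and a softmax over dim 0 (the image axis) gives a matching score.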
        model.model.sem_seg_head.predictor.lang_encoder.get_text_embeddings(text, is_eval=False, name='caption', prompt=False)
        t_emb = getattr(model.model.sem_seg_head.predictor.lang_encoder, '{}_text_embeddings'.format('caption'))
        temperature = model.model.sem_seg_head.predictor.lang_encoder.logit_scale

        logits = vl_similarity(v_emb, t_emb, temperature)
        max_prob, max_id = logits.softmax(0).max(dim=0)
        frame_pth = image_list[max_id.item()]
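
        # Stage 2 (grounding): reload the best-matching image at the larger
        # 512 px scale and ground the query phrase inside it.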
        image_ori = Image.open(frame_pth)
        width = image_ori.size[0]
        height = image_ori.size[1]

        image = transform_grd(image_ori)
        image = np.asarray(image)
        image_ori = np.asarray(image_ori)
        images = torch.from_numpy(image.copy()).permute(2,0,1).cuda()

        batch_inputs = [{'image': images, 'height': height, 'width': width, 'groundings': {'texts': [text]}}]
        outputs = model.model.evaluate_grounding(batch_inputs, None)
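
        # Binarize the predicted grounding masks and overlay them (in green)
        # on the original-resolution image.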
        visual = Visualizer(image_ori, metadata=metadata)
        grd_masks = (outputs[0]['grounding_mask'] > 0).float().cpu().numpy()
        for text_, mask in zip(text, grd_masks):
            demo = visual.draw_binary_mask(mask, color=color, text='', alpha=0.5)
        region_img = demo.get_image()
        candidate_list[max_id.item()] = region_img
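
        # Compose the summary canvas: 448x224 tiles laid out 4 per row (with room
        # for up to 16 candidates) under a 60 px header band for the query text.
        # Non-selected candidates are dimmed, and RGB is flipped to BGR for OpenCV.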
        out_image = np.zeros((224*4+60, 448*4, 3))
        for ii, img in enumerate(candidate_list):
            img = cv2.resize(img, (448, 224))
            if ii != max_id.item():
                img = img * 0.4
            hs, ws = 60+(ii//4)*224, (ii%4)*448
            out_image[hs:hs+224,ws:ws+448,:] = img[:,:,::-1]
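
        # Write the query text into the header band.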
        font = cv2.FONT_HERSHEY_DUPLEX
        fontScale = 1.2
        thickness = 3
        lineType = 2
        bottomLeftCornerOfText = (10, 40)
        fontColor = [255,255,255]
        cv2.putText(out_image, text[0],
                    bottomLeftCornerOfText,
                    font,
                    fontScale,
                    fontColor,
                    thickness,
                    lineType)
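
        # Outline the retrieved tile in red (BGR) and stamp its softmax matching
        # score in a small black box at the tile's bottom-left corner.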
        x1 = (max_id.item()%4) * 448
        y1 = (max_id.item()//4) * 224 + 60
        cv2.rectangle(out_image, (x1, y1), (x1+448, y1+224), (0,0,255), 3)

        y1 = y1 + 224 - 30
        cv2.rectangle(out_image, (x1+2, y1), (x1+60, y1+28), (0,0,0), -1)
        fontScale = 0.8
        thickness = 2
        bottomLeftCornerOfText = (x1, y1+21)
        cv2.putText(out_image, '{:.2f}'.format(max_prob.item()),
                    bottomLeftCornerOfText,
                    font,
                    fontScale,
                    [0,0,255],
                    thickness,
                    lineType)
        if not os.path.exists(output_root):
            os.makedirs(output_root)
        cv2.imwrite(os.path.join(output_root, 'region_retrieval.png'), out_image)


if __name__ == "__main__":
    main()
    sys.exit(0)