# --------------------------------------------------------
# X-Decoder -- Generalized Decoding for Pixel, Image, and Language
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Xueyan Zou ([email protected])
# --------------------------------------------------------
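#
# Usage sketch (flag names are assumptions; see utils/arguments.py /
# load_opt_command for the actual CLI):
#   python demo/demo_region_retrieval.py evaluate \
#       --conf_files configs/xdecoder/svlp_focalt_lang.yaml \
#       --overrides RESUME_FROM /path/to/xdecoder_weights.pt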

import os
import sys
import logging

# Make the repository root importable when this script is run from demo/.
pth = '/'.join(sys.path[0].split('/')[:-1])
sys.path.insert(0, pth)

from PIL import Image
import numpy as np
np.random.seed(0)
import cv2

import torch
from torchvision import transforms

from utils.arguments import load_opt_command

from detectron2.data import MetadataCatalog
from modeling.language.loss import vl_similarity
from modeling.BaseModel import BaseModel
from modeling import build_model
from utils.visualizer import Visualizer
from utils.distributed import init_distributed

logger = logging.getLogger(__name__)


def main(args=None):
    '''
    Region retrieval demo: rank candidate images against a text query,
    then ground the query phrase inside the best-matching image.
    '''
    opt, cmdline_args = load_opt_command(args)
    if cmdline_args.user_dir:
        absolute_user_dir = os.path.abspath(cmdline_args.user_dir)
        opt['base_path'] = absolute_user_dir
    opt = init_distributed(opt)

    # Checkpoint path, output directory, candidate images, and text query
    pretrained_pth = opt['RESUME_FROM']
    output_root = './output'

    image_list = ['inference/images/coco/000.jpg',
                  'inference/images/coco/001.jpg',
                  'inference/images/coco/002.jpg',
                  'inference/images/coco/003.jpg']
    text = ['pizza on the plate']

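    # Build the model, load the checkpoint, and prime the language encoder's
    # default text-embedding buffer with placeholder class names.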
    model = BaseModel(opt, build_model(opt)).from_pretrained(pretrained_pth).eval().cuda()
    model.model.sem_seg_head.predictor.lang_encoder.get_text_embeddings(["background", "background"], is_eval=False)

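    # Two input resolutions: 224px shortest side for the retrieval pass,
    # 512px for the final grounding pass on the best-matching image.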
    transform_ret = transforms.Compose([transforms.Resize(224, interpolation=Image.BICUBIC)])
    transform_grd = transforms.Compose([transforms.Resize(512, interpolation=Image.BICUBIC)])

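    # The ADE20K metadata only feeds the Visualizer; masks are drawn in green.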
    metadata = MetadataCatalog.get('ade20k_panoptic_train')
    color = [0/255, 255/255, 0/255]

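    # Stage 1: embed every candidate image and rank it against the text query.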
    with torch.no_grad():
        batch_inputs = []
        candidate_list = []
        for image_pth in image_list:
            image_ori = Image.open(image_pth)
            width, height = image_ori.size

            image = np.asarray(transform_ret(image_ori))
            candidate_list += [image]
            images = torch.from_numpy(image.copy()).permute(2, 0, 1).cuda()
            batch_inputs += [{'image': images, 'height': height, 'width': width}]

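        # A single batched forward pass; the last caption embedding of each
        # output is taken as the image-level feature for retrieval.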
        outputs = model.model.evaluate(batch_inputs)
        v_emb = torch.cat([x['captions'][-1:] for x in outputs])
        v_emb = v_emb / (v_emb.norm(dim=-1, keepdim=True) + 1e-7)

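        # Embed the query text and score it against all image embeddings;
        # softmax over the image axis gives a matching probability per frame.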
        model.model.sem_seg_head.predictor.lang_encoder.get_text_embeddings(text, is_eval=False, name='caption', prompt=False)
        t_emb = model.model.sem_seg_head.predictor.lang_encoder.caption_text_embeddings
        temperature = model.model.sem_seg_head.predictor.lang_encoder.logit_scale
        logits = vl_similarity(v_emb, t_emb, temperature)
        max_prob, max_id = logits.softmax(0).max(dim=0)
        
        # Stage 2: re-open the best-matching frame at the grounding resolution.
        frame_pth = image_list[max_id.item()]
        image_ori = Image.open(frame_pth)
        width, height = image_ori.size

        image = np.asarray(transform_grd(image_ori))
        image_ori = np.asarray(image_ori)
        images = torch.from_numpy(image.copy()).permute(2, 0, 1).cuda()
        batch_inputs = [{'image': images, 'height': height, 'width': width,
                         'groundings': {'texts': [text]}}]
        outputs = model.model.evaluate_grounding(batch_inputs, None)

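        # Overlay the predicted grounding mask on the original-resolution image.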
        visual = Visualizer(image_ori, metadata=metadata)
        grd_masks = (outputs[0]['grounding_mask'] > 0).float().cpu().numpy()
        for mask in grd_masks:
            demo = visual.draw_binary_mask(mask, color=color, text='', alpha=0.5)
        region_img = demo.get_image()
        candidate_list[max_id.item()] = region_img
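        # Compose a 4-column grid (converted to BGR for OpenCV): the matched
        # frame keeps full brightness, the others are dimmed to 40%, and a
        # 60px header band on top is reserved for the query text.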
        n_rows = (len(candidate_list) + 3) // 4
        out_image = np.zeros((224 * n_rows + 60, 448 * 4, 3))
        for ii, img in enumerate(candidate_list):
            img = cv2.resize(img, (448, 224))
            if ii != max_id.item():
                img = img * 0.4
            hs, ws = 60 + (ii // 4) * 224, (ii % 4) * 448
            out_image[hs:hs + 224, ws:ws + 448, :] = img[:, :, ::-1]

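        # Write the query string into the header band.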
        font = cv2.FONT_HERSHEY_DUPLEX
        cv2.putText(out_image, text[0], (10, 40), font, 1.2,
                    (255, 255, 255), 3, cv2.LINE_AA)

        # Red box around the matched frame in the grid.
        x1 = (max_id.item() % 4) * 448
        y1 = (max_id.item() // 4) * 224 + 60
        cv2.rectangle(out_image, (x1, y1), (x1 + 448, y1 + 224), (0, 0, 255), 3)

        # Matching probability in a black box at the frame's bottom-left.
        y1 = y1 + 224 - 30
        cv2.rectangle(out_image, (x1 + 2, y1), (x1 + 60, y1 + 28), (0, 0, 0), -1)
        cv2.putText(out_image, '%.2f' % max_prob.item(), (x1, y1 + 21),
                    font, 0.8, (0, 0, 255), 2, cv2.LINE_AA)

        os.makedirs(output_root, exist_ok=True)
        cv2.imwrite(os.path.join(output_root, 'region_retrieval.png'), out_image)


if __name__ == "__main__":
    main()
    sys.exit(0)