File size: 2,557 Bytes
684943d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import os
from argparse import ArgumentParser

import cv2
import numpy as np
import torch
from tqdm import tqdm


def _list_images(imgs_path):
    """Return the names of the ``png``/``jpg`` files in ``imgs_path``, sorted by name.

    Sorting matters: features are saved as ``0001.npy``, ``0002.npy``, ... by
    position in this list, and ``os.listdir`` returns entries in arbitrary order.
    """
    return sorted(
        name for name in os.listdir(imgs_path) if name.endswith(("png", "jpg"))
    )


def _enable_gpu_memory_growth(tf):
    """Ask TensorFlow to allocate GPU memory on demand instead of reserving it all.

    Args:
        tf: the ``tensorflow.compat.v1`` module.
    """
    gpus = tf.config.experimental.list_physical_devices('GPU')
    try:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
    except RuntimeError as e:
        # Best effort: TF raises RuntimeError if the GPUs were already initialized;
        # extraction can still proceed without memory growth.
        print(e)


def extract_with_openseg(cfg):
    """Extract dense OpenSeg image embeddings for every input image and save them.

    Reads every file ending in ``png``/``jpg`` (sorted by name) from
    ``<cfg.pipeline.data_path>/input`` and runs the OpenSeg SavedModel at
    ``cfg.feature_extractor.model_path`` on the raw image bytes. The embedding map
    is cropped to its valid region and resized (nearest neighbour) to the image's
    height/width as read by OpenCV. The result is saved as a float16 ``.npy`` array
    to ``<cfg.pipeline.data_path>/lang_features/NNNN.npy``, where ``NNNN`` is the
    1-based, zero-padded position of the image in sorted order.

    Args:
        cfg: config object exposing ``cfg.pipeline.data_path`` and
            ``cfg.feature_extractor.model_path``.

    Raises:
        FileNotFoundError: if the input directory contains no png/jpg images.
        OSError: if OpenCV cannot read one of the images.
    """
    imgs_path = os.path.join(cfg.pipeline.data_path, "input")
    img_names = _list_images(imgs_path)
    if not img_names:
        # Fail fast, before loading TensorFlow and the model.
        raise FileNotFoundError(f"No png/jpg images found in {imgs_path}")

    # Imported here rather than at module level so TensorFlow is only loaded
    # when extraction actually runs.
    import tensorflow as tf2
    tf = tf2.compat.v1

    _enable_gpu_memory_growth(tf)
    openseg = tf2.saved_model.load(
        cfg.feature_extractor.model_path,
        tags=[tf.saved_model.tag_constants.SERVING],
    )
    serve = openseg.signatures["serving_default"]

    save_path = os.path.join(cfg.pipeline.data_path, "lang_features")
    os.makedirs(save_path, exist_ok=True)

    embed_size = 768
    # All-zero text embedding: only the image-side outputs of the model are read
    # below, so the text input just has to be present. Loop-invariant, built once.
    text_emb = tf.zeros([1, 1, embed_size])

    for i, img_name in enumerate(tqdm(img_names, desc="Extracting lang features")):
        img_path = os.path.join(imgs_path, img_name)

        # OpenCV is used only to get the target (height, width) for the resize.
        image = cv2.imread(img_path)
        if image is None:
            raise OSError(f"Failed to read image: {img_path}")
        img_size = image.shape[:2]

        # Pass the raw bytes straight through. Wrapping them in a numpy 'S' array
        # would strip trailing NUL bytes from the encoded image.
        with tf.gfile.GFile(img_path, 'rb') as f:
            image_bytes = f.read()

        results = serve(
            inp_image_bytes=tf.convert_to_tensor(image_bytes),
            inp_text_emb=text_emb,
        )
        img_info = results['image_info']
        # NOTE(review): rows 0 and 2 of image_info are presumably the original size
        # and the resize scale. Their product is the valid (unpadded) extent of the
        # embedding map. Confirm against the OpenSeg export.
        crop_sz = [
            int(img_info[0, 0] * img_info[2, 0]),
            int(img_info[0, 1] * img_info[2, 1]),
        ]
        image_embedding_feat = results['image_embedding_feat'][:, :crop_sz[0], :crop_sz[1]]
        feat_2d = tf.cast(
            tf.image.resize_nearest_neighbor(
                image_embedding_feat, img_size, align_corners=True
            )[0], dtype=tf.float16
        ).numpy()
        np.save(os.path.join(save_path, str(i + 1).zfill(4) + ".npy"), feat_2d)
    
if __name__ == "__main__":
    # Command-line entry point: parse --cfg and run the extraction.
    arg_parser = ArgumentParser(
        description="Extract OpenSeg language features for the images in <data_path>/input."
    )
    # required=True: a missing --cfg now gives an argparse usage error instead of
    # an AttributeError on None deep inside extract_with_openseg().
    arg_parser.add_argument("--cfg", required=True, help="Pipeline configuration.")
    args = arg_parser.parse_args()
    # NOTE(review): extract_with_openseg() reads cfg.pipeline.data_path and
    # cfg.feature_extractor.model_path, but args.cfg is a plain string here, so this
    # call fails with AttributeError. The config presumably needs to be loaded into
    # an attribute-style object first. That loader is not visible in this file;
    # TODO confirm and wire it in.
    extract_with_openseg(args.cfg)