skallewag committed
Commit 6b59f61 · verified · 1 Parent(s): fff8061

Upload 16 files

.gitattributes CHANGED
@@ -76,3 +76,11 @@ inference/images/region_retrieval.png filter=lfs diff=lfs merge=lfs -text
  inference/images/rose.webp filter=lfs diff=lfs merge=lfs -text
  inference/images/street.jpg filter=lfs diff=lfs merge=lfs -text
  inference/images/teaser_new.png filter=lfs diff=lfs merge=lfs -text
+ examples/fries1.png filter=lfs diff=lfs merge=lfs -text
+ examples/fries2.png filter=lfs diff=lfs merge=lfs -text
+ examples/minecraft1.jpg filter=lfs diff=lfs merge=lfs -text
+ examples/ref_vase.JPG filter=lfs diff=lfs merge=lfs -text
+ examples/river1.png filter=lfs diff=lfs merge=lfs -text
+ examples/river1.wav filter=lfs diff=lfs merge=lfs -text
+ examples/river2.png filter=lfs diff=lfs merge=lfs -text
+ examples/vasedeck.mp4 filter=lfs diff=lfs merge=lfs -text
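
Each added line is the pattern that `git lfs track` writes into .gitattributes: files matching it are stored as small LFS pointers rather than raw bytes. A minimal sketch of checking a path against these patterns (`is_lfs_tracked` is a hypothetical helper, not part of this repo; `fnmatch` only approximates git's pattern rules):

from fnmatch import fnmatch

def is_lfs_tracked(path, gitattributes=".gitattributes"):
    # Each relevant line looks like: "<pattern> filter=lfs diff=lfs merge=lfs -text"
    with open(gitattributes) as f:
        for line in f:
            parts = line.split()
            if len(parts) > 1 and "filter=lfs" in parts[1:] and fnmatch(path, parts[0]):
                return True
    return False

# After this commit: is_lfs_tracked("examples/fries1.png") -> True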
examples/corgi1.webp ADDED
examples/corgi2.jpg ADDED
examples/fries1.png ADDED

Git LFS Details

  • SHA256: 3ed0360132103b859d1e58076fd40b88a1dcf06669344b69efa71ad04209bde9
  • Pointer size: 132 Bytes
  • Size of remote file: 1.09 MB
examples/fries2.png ADDED

Git LFS Details

  • SHA256: 3c5e86ca662f880135bb514978b2acee1fece23ffeba40c0cdf300171316b6ba
  • Pointer size: 132 Bytes
  • Size of remote file: 1.06 MB
examples/minecraft1.jpg ADDED

Git LFS Details

  • SHA256: 5b5440edc559e6e9724c3b95a5f7071ef3a6a1e982adbdffe431e45d94d72fad
  • Pointer size: 131 Bytes
  • Size of remote file: 120 kB
examples/placeholder.png ADDED
examples/ref_vase.JPG ADDED

Git LFS Details

  • SHA256: 3f5a75fc6567709c8fe250df8d287ca72435cdf04b474bfdba6f1cf7b5d2e4e6
  • Pointer size: 132 Bytes
  • Size of remote file: 3.54 MB
examples/river1.png ADDED

Git LFS Details

  • SHA256: aaa017dbfbf019357846e556908a032d508a48564fc97df4472c504ecba26f56
  • Pointer size: 131 Bytes
  • Size of remote file: 694 kB
examples/river1.wav ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:a71fa0c20c27f4ffe7567f437aec982877b5ccf34a7563d5603919bf6899a03a
+ size 397484
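
Those three lines are the entire on-disk content of the pointer file; the audio bytes live in the LFS object store, keyed by the sha256 oid. A minimal sketch of reading such a pointer (`parse_lfs_pointer` is a hypothetical helper, not part of this repo):

def parse_lfs_pointer(path):
    # Pointer files are plain text: "version ...", "oid sha256:...", "size ...".
    fields = {}
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            key, _, value = line.strip().partition(" ")
            if key:
                fields[key] = value
    return fields

# parse_lfs_pointer("examples/river1.wav")["size"] -> "397484"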
examples/river1_mask.png ADDED
examples/river2.png ADDED

Git LFS Details

  • SHA256: 51f602ddca840ac409283930b07a58ec617446ee825550dbb1ec4f0abe39d1f6
  • Pointer size: 132 Bytes
  • Size of remote file: 2.03 MB
examples/vasedeck.mp4 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:726107c05e5837feb5c761714ef3eb2403b338392732ac10ff61969771cdd5a1
+ size 22498026
examples/zebras1.jpg ADDED
examples/zebras2.jpg ADDED
tasks/__init__.py ADDED
@@ -0,0 +1 @@
+ from .interactive import interactive_infer_video, interactive_infer_image
tasks/interactive.py ADDED
@@ -0,0 +1,268 @@
+ # --------------------------------------------------------
+ # SEEM -- Segment Everything Everywhere All At Once
+ # Copyright (c) 2022 Microsoft
+ # Licensed under The MIT License [see LICENSE for details]
+ # Written by Xueyan Zou ([email protected])
+ # --------------------------------------------------------
+
+ import torch
+ import numpy as np
+ import torch.nn.functional as F
+ from PIL import Image
+ from torchvision import transforms
+ from utils.visualizer import Visualizer
+ from detectron2.utils.colormap import random_color
+ from detectron2.data import MetadataCatalog
+ from detectron2.structures import BitMasks
+ from modeling.language.loss import vl_similarity
+ from utils.constants import COCO_PANOPTIC_CLASSES
+ from detectron2.data.datasets.builtin_meta import COCO_CATEGORIES
+
+ import cv2
+ import os
+ import glob
+ import subprocess
+ import random
+
+ # Shared preprocessing and COCO panoptic metadata used by both entry points.
+ transform = transforms.Compose([transforms.Resize(512, interpolation=Image.BICUBIC)])
+ metadata = MetadataCatalog.get('coco_2017_train_panoptic')
+ all_classes = [name.replace('-other', '').replace('-merged', '') for name in COCO_PANOPTIC_CLASSES] + ["others"]
+ colors_list = [(np.array(color['color']) / 255).tolist() for color in COCO_CATEGORIES] + [[1, 1, 1]]
+
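+ # Note: COCO panoptic defines 133 categories; both lists above append one
+ # extra "others" / white entry, and the drawing code below indexes colors
+ # with `pred_class[0] % 133`, which stays inside the 133 real COCO colors.
+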
+ def interactive_infer_image(model, audio_model, image, tasks, refimg=None, reftxt=None, audio_pth=None, video_pth=None):
+     image_ori = transform(image['image'])
+     mask_ori = image['mask']
+     width = image_ori.size[0]
+     height = image_ori.size[1]
+     image_ori = np.asarray(image_ori)
+     visual = Visualizer(image_ori, metadata=metadata)
+     images = torch.from_numpy(image_ori.copy()).permute(2,0,1).cuda()
+
+     # stroke_inimg = None
+     # stroke_refimg = None
+
+     data = {"image": images, "height": height, "width": width}
+     if len(tasks) == 0:
+         tasks = ["Panoptic"]
+
+     # Initialize the task switches; each branch below turns on the one it needs.
+     model.model.task_switch['spatial'] = False
+     model.model.task_switch['visual'] = False
+     model.model.task_switch['grounding'] = False
+     model.model.task_switch['audio'] = False
+
+     example = None
+     if 'Example' in tasks:
+         # Encode the stroked region of the reference image into visual query features.
+         model.model.task_switch['visual'] = True
+         model.model.task_switch['spatial'] = True
+         refimg_ori, refimg_mask = refimg['image'], refimg['mask']
+         refimg_ori = transform(refimg_ori)
+         _width = refimg_ori.size[0]
+         _height = refimg_ori.size[1]
+         refimg_ori = np.asarray(refimg_ori)
+         refimg_ori_np = refimg_ori.copy()
+         images = torch.from_numpy(refimg_ori.copy()).permute(2,0,1).cuda()
+         batched_inputs = [{'image': images, 'height': _height, 'width': _width, 'spatial_query': {}}]
+
+         refimg_mask = np.asarray(refimg_mask)[:,:,0:1].copy()
+         refimg_mask = torch.from_numpy(refimg_mask).permute(2,0,1)[None,]
+         refimg_mask = (F.interpolate(refimg_mask, (_height, _width), mode='bilinear') > 0)
+         batched_inputs[0]['spatial_query']['rand_shape'] = refimg_mask
+         outputs_refimg, img_shape = model.model.evaluate_referring_image(batched_inputs)
+         model.model.task_switch['spatial'] = False
+         data['visual'] = outputs_refimg
+
+         # overlay = refimg_mask[0,0].float().numpy()[:,:,None] * np.array([0,0,255])
+         # x = refimg_ori_np
+         # stroke_refimg = x * (1 - refimg_mask[0,0].float().numpy()[:,:,None]) + (x * refimg_mask[0,0].numpy()[:,:,None] * 0.2 + overlay * 0.8)
+         # stroke_refimg = Image.fromarray(stroke_refimg.astype(np.uint8))
+
+     stroke = None
+     if 'Stroke' in tasks:
+         # Use the stroke drawn on the input image itself as the spatial query.
+         model.model.task_switch['spatial'] = True
+         mask_ori = np.asarray(mask_ori)[:,:,0:1].copy()
+         mask_ori = torch.from_numpy(mask_ori).permute(2,0,1)[None,]
+         mask_ori = (F.interpolate(mask_ori, (height, width), mode='bilinear') > 0)
+         data['stroke'] = mask_ori
+
+         # overlay = mask_ori[0,0].float().numpy()[:,:,None] * np.array([0,255,0])
+         # x = image_ori
+         # stroke_inimg = x * (1 - mask_ori[0,0].float().numpy()[:,:,None]) + (x * mask_ori[0,0].numpy()[:,:,None] * 0.2 + overlay * 0.8)
+         # stroke_inimg = Image.fromarray(stroke_inimg.astype(np.uint8))
+
+     text = None
+     if 'Text' in tasks:
+         model.model.task_switch['grounding'] = True
+         data['text'] = [reftxt]
+
+     audio = None
+     if 'Audio' in tasks:
+         # Transcribe the clip, then treat the transcript as a grounding phrase.
+         model.model.task_switch['audio'] = True
+         audio_result = audio_model.transcribe(audio_pth)
+         data['audio'] = [audio_result['text']]
+
+     batch_inputs = [data]
+     if 'Panoptic' in tasks:
+         model.model.metadata = metadata
+         results = model.model.evaluate(batch_inputs)
+         pano_seg = results[-1]['panoptic_seg'][0]
+         pano_seg_info = results[-1]['panoptic_seg'][1]
+         demo = visual.draw_panoptic_seg(pano_seg.cpu(), pano_seg_info)  # RGB image
+         res = demo.get_image()
+         return Image.fromarray(res), None
+     else:
+         results, image_size, extra = model.model.evaluate_demo(batch_inputs)
+
+     # If the task carries a spatial query, match mask proposals against it:
+     if 'Stroke' in tasks:
+         v_emb = results['pred_maskembs']
+         s_emb = results['pred_pspatials']
+         pred_masks = results['pred_masks']
+
+         # Pick, per image, the mask proposal whose embedding best matches the query.
+         pred_logits = v_emb @ s_emb.transpose(1,2)
+         logits_idx_y = pred_logits[:,:,0].max(dim=1)[1]
+         logits_idx_x = torch.arange(len(logits_idx_y), device=logits_idx_y.device)
+         logits_idx = torch.stack([logits_idx_x, logits_idx_y]).tolist()
+         pred_masks_pos = pred_masks[logits_idx]
+         pred_class = results['pred_logits'][logits_idx].max(dim=-1)[1]
+
+     elif 'Example' in tasks:
+         v_emb = results['pred_maskembs']
+         s_emb = results['pred_pvisuals']
+         pred_masks = results['pred_masks']
+
+         pred_logits = v_emb @ s_emb.transpose(1,2)
+         logits_idx_y = pred_logits[:,:,0].max(dim=1)[1]
+         logits_idx_x = torch.arange(len(logits_idx_y), device=logits_idx_y.device)
+         logits_idx = torch.stack([logits_idx_x, logits_idx_y]).tolist()
+         pred_masks_pos = pred_masks[logits_idx]
+         pred_class = results['pred_logits'][logits_idx].max(dim=-1)[1]
+
+     elif 'Text' in tasks:
+         pred_masks = results['pred_masks'][0]
+         v_emb = results['pred_captions'][0]
+         t_emb = extra['grounding_class']
+
+         # Cosine-normalize both embeddings before the similarity lookup.
+         t_emb = t_emb / (t_emb.norm(dim=-1, keepdim=True) + 1e-7)
+         v_emb = v_emb / (v_emb.norm(dim=-1, keepdim=True) + 1e-7)
+
+         temperature = model.model.sem_seg_head.predictor.lang_encoder.logit_scale
+         out_prob = vl_similarity(v_emb, t_emb, temperature=temperature)
+
+         matched_id = out_prob.max(0)[1]
+         pred_masks_pos = pred_masks[matched_id,:,:]
+         pred_class = results['pred_logits'][0][matched_id].max(dim=-1)[1]
+
+     elif 'Audio' in tasks:
+         pred_masks = results['pred_masks'][0]
+         v_emb = results['pred_captions'][0]
+         t_emb = extra['audio_class']
+
+         t_emb = t_emb / (t_emb.norm(dim=-1, keepdim=True) + 1e-7)
+         v_emb = v_emb / (v_emb.norm(dim=-1, keepdim=True) + 1e-7)
+
+         temperature = model.model.sem_seg_head.predictor.lang_encoder.logit_scale
+         out_prob = vl_similarity(v_emb, t_emb, temperature=temperature)
+
+         matched_id = out_prob.max(0)[1]
+         pred_masks_pos = pred_masks[matched_id,:,:]
+         pred_class = results['pred_logits'][0][matched_id].max(dim=-1)[1]
+
+     # Interpolate the selected masks back to the original image size.
+     pred_masks_pos = (F.interpolate(pred_masks_pos[None,], image_size[-2:], mode='bilinear')[0,:,:data['height'],:data['width']] > 0.0).float().cpu().numpy()
+     texts = [all_classes[pred_class[0]]]
+
+     for idx, mask in enumerate(pred_masks_pos):
+         # color = random_color(rgb=True, maximum=1).astype(np.int32).tolist()
+         out_txt = texts[idx] if 'Text' not in tasks else reftxt
+         demo = visual.draw_binary_mask(mask, color=colors_list[pred_class[0] % 133], text=out_txt)
+         res = demo.get_image()
+     torch.cuda.empty_cache()
+     # return Image.fromarray(res), stroke_inimg, stroke_refimg
+     return Image.fromarray(res), None
+
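+ # Note on the matching step in the Stroke/Example branches above (reused in the
+ # per-frame loop below): with B images, P mask proposals and Q queries,
+ #     pred_logits = v_emb @ s_emb.transpose(1, 2)   # (B,P,D) x (B,D,Q) -> (B,P,Q)
+ # and max(dim=1) picks the best proposal per image. Shape-only illustration:
+ #     v = torch.randn(1, 100, 512)   # 100 proposal embeddings
+ #     s = torch.randn(1, 1, 512)     # one spatial/visual query embedding
+ #     best = (v @ s.transpose(1, 2))[:, :, 0].max(dim=1)[1]   # shape (1,)
+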
+ def interactive_infer_video(model, audio_model, image, tasks, refimg=None, reftxt=None, audio_pth=None, video_pth=None):
+     if 'Video' in tasks:
+         input_dir = video_pth.replace('.mp4', '')
+         input_name = input_dir.split('/')[-1]
+         random_number = str(random.randint(10000, 99999))
+         output_dir = input_dir + '_output'
+         output_name = output_dir.split('/')[-1]
+         output_file = video_pth.replace('.mp4', '_{}_output.mp4'.format(random_number))
+         frame_interval = 10
+
+         # Ensure the frame and output directories exist.
+         if not os.path.exists(input_dir):
+             os.makedirs(input_dir)
+         if not os.path.exists(output_dir):
+             os.makedirs(output_dir)
+
+         # Extract frames at 5 fps with FFmpeg.
+         ffmpeg_cmd = "ffmpeg -i {} -vf \"fps=5\" {}/%04d.png".format(video_pth, input_dir)
+         os.system(ffmpeg_cmd)
+
+         # Encode the reference-image stroke into visual query features, as in the
+         # 'Example' branch of interactive_infer_image.
+         data = {}
+         model.model.task_switch['visual'] = True
+         model.model.task_switch['spatial'] = True
+         refimg_ori, refimg_mask = refimg['image'], refimg['mask']
+         refimg_ori = transform(refimg_ori)
+         _width = refimg_ori.size[0]
+         _height = refimg_ori.size[1]
+         refimg_ori = np.asarray(refimg_ori)
+         refimg_ori_np = refimg_ori.copy()
+         images = torch.from_numpy(refimg_ori.copy()).permute(2,0,1).cuda()
+         batched_inputs = [{'image': images, 'height': _height, 'width': _width, 'spatial_query': {}}]
+
+         refimg_mask = np.asarray(refimg_mask)[:,:,0:1].copy()
+         refimg_mask = torch.from_numpy(refimg_mask).permute(2,0,1)[None,]
+         refimg_mask = (F.interpolate(refimg_mask, (_height, _width), mode='bilinear') > 0)
+         batched_inputs[0]['spatial_query']['rand_shape'] = refimg_mask
+         outputs_refimg, img_shape = model.model.evaluate_referring_image(batched_inputs)
+         model.model.task_switch['visual'] = False
+         model.model.task_switch['spatial'] = False
+         data['visual'] = outputs_refimg
+
+         # Segment every extracted frame against the visual query.
+         model.model.task_switch['visual'] = True
+         frame_pths = sorted(glob.glob(os.path.join(input_dir, '*.png')))
+         for frame_pth in frame_pths:
+             image_ori = transform(Image.open(frame_pth))
+             width = image_ori.size[0]
+             height = image_ori.size[1]
+             image_ori = np.asarray(image_ori)
+             visual = Visualizer(image_ori[:,:,::-1], metadata=metadata)
+             images = torch.from_numpy(image_ori.copy()).permute(2,0,1).cuda()
+
+             data.update({"image": images, "height": height, "width": width})
+             batch_inputs = [data]
+             results, image_size, extra = model.model.evaluate_demo(batch_inputs)
+
+             v_emb = results['pred_maskembs']
+             s_emb = results['pred_pvisuals']
+             pred_masks = results['pred_masks']
+
+             pred_logits = v_emb @ s_emb.transpose(1,2)
+             logits_idx_y = pred_logits[:,:,0].max(dim=1)[1]
+             logits_idx_x = torch.arange(len(logits_idx_y), device=logits_idx_y.device)
+             logits_idx = torch.stack([logits_idx_x, logits_idx_y]).tolist()
+             pred_masks_pos = pred_masks[logits_idx]
+             pred_class = results['pred_logits'][logits_idx].max(dim=-1)[1]
+
+             pred_masks_pos = (F.interpolate(pred_masks_pos[None,], image_size[-2:], mode='bilinear')[0,:,:data['height'],:data['width']] > 0.0).float().cpu().numpy()
+             texts = [all_classes[pred_class[0]]]
+
+             for idx, mask in enumerate(pred_masks_pos):
+                 out_txt = texts[idx]
+                 demo = visual.draw_binary_mask(mask, color=colors_list[pred_class[0] % 133], text=out_txt)
+
+             res = demo.get_image()
+             output_pth = frame_pth.replace(input_name, output_name)
+             cv2.imwrite(output_pth, res)
+
+         # Re-encode the annotated frames into the output video.
+         ffmpeg_cmd = "ffmpeg -framerate 5 -pattern_type glob -i '{}/*.png' -c:v libx264 {}".format(output_dir, output_file)
+         os.system(ffmpeg_cmd)
+
+         return None, output_file
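
Taken together, the two entry points cover single-image and video inference for the demo. A minimal usage sketch for the image path (hedged: this commit does not show how `model` and `audio_model` are constructed, so the placeholder lines below are assumptions):

from PIL import Image
from tasks import interactive_infer_image

# Placeholders -- built elsewhere in the demo app, not in this commit:
# `model` must expose model.model.evaluate / evaluate_demo as used above;
# `audio_model` needs .transcribe(path) returning {'text': ...} (Whisper-style).
model = ...
audio_model = ...

image = {
    "image": Image.open("examples/river1.png").convert("RGB"),
    "mask": Image.open("examples/river1_mask.png").convert("RGB"),
}

# Text-grounded segmentation; returns (annotated PIL image, None).
res, _ = interactive_infer_image(model, audio_model, image, ["Text"], reftxt="river")
res.save("river1_text_result.png")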