This view is limited to 50 files because the change set contains too many files; see the raw diff for the complete changes.
Files changed (50)
  1. app.py +387 -0
  2. requirements.txt +44 -0
  3. rynnec/__init__.py +269 -0
  4. rynnec/constants.py +47 -0
  5. rynnec/mm_utils.py +733 -0
  6. rynnec/model/__init__.py +196 -0
  7. rynnec/model/encoder.py +282 -0
  8. rynnec/model/extension/__init__.py +1 -0
  9. rynnec/model/extension/sam2_base.py +298 -0
  10. rynnec/model/loss.py +597 -0
  11. rynnec/model/predictor/__init__.py +1 -0
  12. rynnec/model/predictor/sam2_predictor.py +724 -0
  13. rynnec/model/processor.py +401 -0
  14. rynnec/model/projector.py +161 -0
  15. rynnec/model/region_encoder.py +77 -0
  16. rynnec/model/rynnec_arch.py +271 -0
  17. rynnec/model/rynnec_qwen2.py +638 -0
  18. rynnec/model/sam2.py +133 -0
  19. rynnec/model/sam2_train.py +134 -0
  20. rynnec/model/utils.py +61 -0
  21. rynnec/model/videollama3_encoder/__init__.py +3 -0
  22. rynnec/model/videollama3_encoder/configuration_videollama3_encoder.py +71 -0
  23. rynnec/model/videollama3_encoder/image_processing_videollama3.py +473 -0
  24. rynnec/model/videollama3_encoder/modeling_videollama3_encoder.py +555 -0
  25. rynnec/rynnec_trainer.py +496 -0
  26. rynnec/train.py +832 -0
  27. third_parts/sam2/__init__.py +9 -0
  28. third_parts/sam2/automatic_mask_generator.py +434 -0
  29. third_parts/sam2/build_sam.py +89 -0
  30. third_parts/sam2/csrc/connected_components.cu +289 -0
  31. third_parts/sam2/modeling/__init__.py +5 -0
  32. third_parts/sam2/modeling/backbones/__init__.py +5 -0
  33. third_parts/sam2/modeling/backbones/hieradet.py +295 -0
  34. third_parts/sam2/modeling/backbones/image_encoder.py +133 -0
  35. third_parts/sam2/modeling/backbones/utils.py +95 -0
  36. third_parts/sam2/modeling/memory_attention.py +169 -0
  37. third_parts/sam2/modeling/memory_encoder.py +181 -0
  38. third_parts/sam2/modeling/position_encoding.py +221 -0
  39. third_parts/sam2/modeling/sam/__init__.py +5 -0
  40. third_parts/sam2/modeling/sam/mask_decoder.py +299 -0
  41. third_parts/sam2/modeling/sam/prompt_encoder.py +182 -0
  42. third_parts/sam2/modeling/sam/transformer.py +328 -0
  43. third_parts/sam2/modeling/sam2_base.py +830 -0
  44. third_parts/sam2/modeling/sam2_utils.py +149 -0
  45. third_parts/sam2/sam2_configs/__init__.py +5 -0
  46. third_parts/sam2/sam2_configs/sam2_hiera_b+.yaml +113 -0
  47. third_parts/sam2/sam2_configs/sam2_hiera_l.yaml +117 -0
  48. third_parts/sam2/sam2_configs/sam2_hiera_s.yaml +116 -0
  49. third_parts/sam2/sam2_configs/sam2_hiera_t.yaml +118 -0
  50. third_parts/sam2/sam2_image_predictor.py +446 -0
app.py ADDED
@@ -0,0 +1,387 @@
+ import argparse
+ import colorsys
+
+ import cv2
+ import numpy as np
+ import torch
+ import gradio as gr
+ import spaces
+ from PIL import Image
+ from tqdm import tqdm
+ from torchvision.transforms import v2
+ from transformers import SamModel, SamProcessor
+
+ from rynnec import disable_torch_init, model_init, mm_infer, mm_infer_segmentation
+ from rynnec.mm_utils import annToMask, load_video, load_images
22
+
23
+
24
+ def get_hsv_palette(n_colors):
25
+ hues = np.linspace(0, 1, int(n_colors) + 1)[1:-1]
26
+ s = 0.8
27
+ v = 0.9
28
+ palette = [(0.0, 0.0, 0.0)] + [
29
+ colorsys.hsv_to_rgb(h_i, s, v) for h_i in hues
30
+ ]
31
+ return (255 * np.asarray(palette)).astype("uint8")
32
+
33
+
34
+ def colorize_masks(images, index_masks, fac: float = 0.8, draw_contour=True, edge_thickness=20):
35
+ max_idx = max([m.max() for m in index_masks])
36
+ palette = get_hsv_palette(max_idx + 1)
37
+ color_masks = []
38
+ out_frames = []
39
+ for img, mask in tqdm(zip(images, index_masks), desc='Visualize masks ...'):
40
+ clr_mask = palette[mask.astype("int")]
41
+ blended_img = img
42
+
43
+ blended_img = compose_img_mask(blended_img, clr_mask, fac)
44
+
45
+ if draw_contour:
46
+ blended_img = draw_contours_on_image(blended_img, mask, clr_mask,
47
+ brightness_factor=1.8,
48
+ alpha=0.6,
49
+ thickness=edge_thickness)
50
+ out_frames.append(blended_img)
51
+
52
+ return out_frames, color_masks
53
+
54
+
55
+ def compose_img_mask(img, color_mask, fac: float = 0.5):
56
+ mask_region = (color_mask.sum(axis=-1) > 0)[..., None]
57
+ out_f = img.copy() / 255
58
+ out_f[mask_region[:, :, 0]] = fac * img[mask_region[:, :, 0]] / 255 + (1 - fac) * color_mask[mask_region[:, :, 0]] / 255
59
+ out_u = (255 * out_f).astype("uint8")
60
+ return out_u
61
+
62
+
63
+ def draw_contours_on_image(img, index_mask, color_mask, brightness_factor=1.6, alpha=0.5, thickness=2, ignore_index=0):
64
+ img = img.astype("float32")
65
+ overlay = img.copy()
66
+
67
+ unique_indices = np.unique(index_mask)
68
+ if ignore_index is not None:
69
+ unique_indices = [idx for idx in unique_indices if idx != ignore_index]
70
+
71
+ for i in unique_indices:
72
+ bin_mask = (index_mask == i).astype("uint8") * 255
73
+ if bin_mask.sum() == 0:
74
+ continue
75
+
76
+ contours, _ = cv2.findContours(bin_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
77
+
78
+ color = color_mask[index_mask == i][0].astype("float32")
79
+ bright_color = np.clip(color * brightness_factor, 0, 255).tolist()
80
+
81
+ cv2.drawContours(overlay, contours, -1, bright_color, thickness)
82
+
83
+ blended = (1 - alpha) * img + alpha * overlay
84
+ return np.clip(blended, 0, 255).astype("uint8")
85
+
86
+
87
+ def extract_first_frame_from_video(video):
88
+ cap = cv2.VideoCapture(video)
89
+ success, frame = cap.read()
90
+ cap.release()
91
+ if success:
92
+ return Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
93
+ return None
94
+
95
+
96
+ def extract_points_from_mask(mask_pil):
97
+ mask = np.asarray(mask_pil)[..., 0]
98
+ coords = np.nonzero(mask)
99
+ coords = np.stack((coords[1], coords[0]), axis=1)
100
+
101
+ return coords
102
+
103
+ def add_contour(img, mask, color=(1., 1., 1.)):
104
+ img = img.copy()
105
+
106
+ mask = mask.astype(np.uint8) * 255
107
+ contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
108
+ cv2.drawContours(img, contours, -1, color, thickness=8)
109
+
110
+ return img
111
+
112
+
113
+ def load_first_frame(video_path):
114
+ cap = cv2.VideoCapture(video_path)
115
+ ret, frame = cap.read()
116
+ cap.release()
117
+ if not ret:
118
+ raise gr.Error("Could not read the video file.")
119
+ frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
120
+ image = Image.fromarray(frame)
121
+ return image
122
+
123
+
124
+ def clear_masks():
125
+ return [], [], [], []
126
+
127
+ def clear_all():
128
+ return [], [], [], [], None, "", ""
129
+
130
+
131
+ @spaces.GPU(duration=120)
132
+ def apply_sam(image, input_points):
133
+ inputs = sam_processor(image, input_points=input_points, return_tensors="pt").to(device)
134
+
135
+ with torch.no_grad():
136
+ outputs = sam_model(**inputs)
137
+
138
+ masks = sam_processor.image_processor.post_process_masks(outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu())[0][0]
139
+ scores = outputs.iou_scores[0, 0]
140
+
141
+ mask_selection_index = scores.argmax()
142
+
143
+ mask_np = masks[mask_selection_index].numpy()
144
+
145
+ return mask_np
146
+
147
+
148
+ @spaces.GPU(duration=120)
149
+ def run(mode, images, timestamps, masks, mask_ids, instruction, mask_output_video):
150
+ if mode == "QA":
151
+ response = run_text_inference(images, timestamps, masks, mask_ids, instruction)
152
+ else:
153
+ response, mask_output_video = run_seg_inference(images, timestamps, instruction)
154
+ return response, mask_output_video
155
+
156
+
157
+ def run_text_inference(images, timestamps, masks, mask_ids, instruction):
158
+ masks = torch.from_numpy(np.stack(masks, axis=0))
159
+
160
+ if "<video>" not in instruction:
161
+ instruction = "<video>\n" + instruction
162
+
163
+ if len(masks) >= 2:
164
+ obj_str = f"<video>\nThere are {len(masks)} objects in the video: " + ", ".join([f"<object{i}> [<REGION>]" for i in range(len(masks))])
165
+ instruction = instruction.replace("<video>\n", obj_str)
166
+ else:
167
+ instruction = instruction.replace("<object0>", '[<REGION>]')
168
+
169
+ output = mm_infer(
170
+ (images, timestamps),
171
+ processor,
172
+ instruction,
173
+ model=model,
174
+ tokenizer=processor.tokenizer,
175
+ do_sample=False,
176
+ modal='video',
177
+ masks=masks.cuda() if masks is not None else None,
178
+ mask_ids=mask_ids
179
+ )
180
+
181
+ return output
182
+
183
+
184
+ def run_seg_inference(images, timestamps, instruction):
185
+ output, masks = mm_infer_segmentation(
186
+ (images, timestamps),
187
+ processor,
188
+ instruction,
189
+ model=model,
190
+ tokenizer=processor.tokenizer,
191
+ do_sample=False,
192
+ modal='video',
193
+ )
194
+
195
+ w, h = images[0].size
196
+ masks = v2.Resize([h, w])(masks).cpu().numpy()
197
+
198
+ mask_list_video = []
199
+
200
+ images = [np.array(image) for image in images]
201
+ masks = [mask[0] for mask in masks]
202
+ show_images, _ = colorize_masks(images, masks)
203
+ for i, image in enumerate(show_images):
204
+ if masks[i].sum() > 1000:
205
+ mask_list_video.append((Image.fromarray(image), f"Frame {i}"))
206
+
207
+ return output, mask_list_video
208
+
209
+
210
+ def generate_masks_video(image, mask_list_video, mask_raw_list_video, mask_ids, frame_idx):
211
+ image['image'] = image['background'].convert('RGB')
212
+ # del image['background'], image['composite']
213
+ assert len(image['layers']) == 1, f"Expected 1 layer, got {len(image['layers'])}"
214
+
215
+ mask = Image.fromarray((np.asarray(image['layers'][0])[..., 3] > 0).astype(np.uint8) * 255).convert('RGB')
216
+ points = extract_points_from_mask(mask)
217
+ np.random.seed(0)
218
+ if points.shape[0] == 0:
219
+ raise gr.Error("No points selected")
220
+
221
+ points_selected_indices = np.random.choice(points.shape[0], size=min(points.shape[0], 8), replace=False)
222
+ points = points[points_selected_indices]
223
+ coords = [points.tolist()]
224
+ mask_np = apply_sam(image['image'], coords)
225
+
226
+ mask_raw_list_video.append(mask_np)
227
+ mask_image = Image.fromarray((mask_np[:,:,np.newaxis] * np.array(image['image'])).astype(np.uint8))
228
+
229
+ mask_list_video.append((mask_image, f"<object{len(mask_list_video)}>"))
230
+ # Return a list containing the mask image.
231
+ image['layers'] = []
232
+ image['composite'] = image['background']
233
+ mask_ids.append(frame_idx)
234
+ return mask_list_video, image, mask_list_video, mask_raw_list_video, mask_ids
235
+
236
+
237
+ if __name__ == "__main__":
238
+ parser = argparse.ArgumentParser(description="RynnEC gradio demo")
239
+ parser.add_argument("--model-path", type=str, default="Alibaba-DAMO-Academy/RynnEC-2B", help="Path to the model checkpoint")
240
+ parser.add_argument("--port", type=int, default=4001)
241
+
242
+ args_cli = parser.parse_args()
243
+
244
+ with gr.Blocks(theme=gr.themes.Soft(primary_hue="amber")) as demo:
245
+
246
+ mask_list = gr.State([])
247
+ mask_raw_list = gr.State([])
248
+ mask_list_video = gr.State([])
249
+ mask_raw_list_video = gr.State([])
250
+
251
+
252
+ HEADER = ("""
253
+ <div>
254
+ <h1>RynnEC Demo</h1>
255
+ <h5 style="margin: 0;">Feel free to click on anything that grabs your interest!</h5>
256
+ <h5 style="margin: 0;">If you enjoy this demo, please give us a star ⭐ on GitHub or a 💖 on this Space.</h5>
257
+ </div>
259
+ <div style="display: flex; justify-content: left; margin-top: 10px;">
260
+ <a href="https://arxiv.org/pdf/2501.00599"><img src="https://img.shields.io/badge/Arxiv-2501.00599-ECA8A7" style="margin-right: 5px;"></a>
261
+ <a href="https://github.com/DAMO-NLP-SG/VideoRefer"><img src='https://img.shields.io/badge/Github-VideoRefer-F7C97E' style="margin-right: 5px;"></a>
262
+ <a href="https://github.com/DAMO-NLP-SG/VideoLLaMA3"><img src='https://img.shields.io/badge/Github-VideoLLaMA3-9DC3E6' style="margin-right: 5px;"></a>
263
+ </div>
264
+ """)
265
+
266
+
267
+ image_tips = """
268
+ ### 💡 Tips:
269
+
270
+ 🧸 Upload an image, and you can use the drawing tool✍️ to highlight the areas you're interested in.
271
+
272
+ 🔖 For single-object caption mode, simply select the area and click the 'Generate Caption' button to receive a caption for the object.
273
+
274
+ 🔔 In QA mode, you can generate multiple masks by clicking the 'Generate Mask' button multiple times. Afterward, use the corresponding object id to ask questions.
275
+
276
+ 📌 Click the 'Clear Masks' button to clear the currently generated masks.
277
+
278
+ """
279
+
280
+ video_tips = """
281
+ ### 💡 Tips:
282
+ 🧸 Upload a video, and you can use the drawing tool ✍️ to highlight the areas you're interested in on the first frame.
283
+
284
+ 🔔 In QA mode, you can generate multiple masks by clicking the 'Generate Mask' button multiple times. Afterward, use the corresponding object id to ask questions.
285
+
286
+ 📌 Click the 'Clear Masks' button to clear the currently generated masks.
287
+
288
+ """
289
+
290
+
291
+ with gr.TabItem("Video"):
292
+ with gr.Row():
293
+ with gr.Column():
294
+ video_input = gr.Video(label="Video", interactive=True)
295
+ frame_idx = gr.Slider(minimum=0, maximum=0, value=0, step=1, label="Select Frame", interactive=False)
296
+ selected_frame = gr.ImageEditor(
297
+ label="Annotate Frame",
298
+ type="pil",
299
+ sources=[],
300
+ interactive=True,
301
+ )
302
+ generate_mask_btn_video = gr.Button("1️⃣ Generate Mask", visible=True, variant="primary")
303
+ gr.Examples([f"./demo/videos/{i+1}.mp4" for i in range(4)], inputs=video_input, label="Examples")
304
+
305
+ with gr.Column():
306
+ mode_video = gr.Radio(label="Mode", choices=["QA", "Seg"], value="QA")
307
+ mask_output_video = gr.Gallery(label="Referred Masks", object_fit='scale-down')
308
+
309
+ query_video = gr.Textbox(label="Question", value="Please describe <object0>.", interactive=True, visible=True)
310
+ response_video = gr.Textbox(label="Answer", interactive=False)
311
+
312
+ submit_btn_video = gr.Button("Generate Caption", variant="primary", visible=False)
313
+ submit_btn_video1 = gr.Button("2️⃣ Generate Answer", variant="primary", visible=True)
314
+ description_video = gr.Textbox(label="Output", visible=False)
315
+
316
+ clear_masks_btn_video = gr.Button("Clear Masks", variant="secondary")
317
+
318
+ gr.Markdown(video_tips)
319
+
320
+ frames = gr.State(value=[])
321
+ timestamps = gr.State(value=[])
322
+ mask_ids = gr.State(value=[])
323
+
324
+ def on_video_upload(video_path):
325
+ frames, timestamps = load_video(video_path, fps=1, max_frames=128)
326
+ frames = [Image.fromarray(x.transpose(1, 2, 0)) for x in frames]
327
+ return frames, timestamps, frames[0], gr.update(value=0, maximum=len(frames) - 1, interactive=True)
328
+
329
+ def on_frame_idx_change(frame_idx, frames):
330
+ return frames[frame_idx]
331
+
332
+ def to_seg_mode():
333
+ return (
334
+ *[gr.update(visible=False) for _ in range(4)],
335
+ []
336
+ )
337
+
338
+ def to_qa_mode():
339
+ return (
340
+ *[gr.update(visible=True) for _ in range(4)],
341
+ []
342
+ )
343
+
344
+ def on_mode_change(mode):
345
+ if mode == "QA":
346
+ return to_qa_mode()
347
+ return to_seg_mode()
348
+
349
+ mode_video.change(on_mode_change, inputs=[mode_video], outputs=[frame_idx, selected_frame, generate_mask_btn_video, response_video, mask_output_video])
350
+ video_input.change(on_video_upload, inputs=[video_input], outputs=[frames, timestamps, selected_frame, frame_idx])
351
+ frame_idx.change(on_frame_idx_change, inputs=[frame_idx, frames], outputs=[selected_frame])
352
+
353
+ generate_mask_btn_video.click(
354
+ fn=generate_masks_video,
355
+ inputs=[selected_frame, mask_list_video, mask_raw_list_video, mask_ids, frame_idx],
356
+ outputs=[mask_output_video, selected_frame, mask_list_video, mask_raw_list_video, mask_ids]
357
+ )
358
+
359
+ submit_btn_video1.click(
360
+ fn=run,
361
+ inputs=[mode_video, frames, timestamps, mask_raw_list_video, mask_ids, query_video, mask_output_video],
362
+ outputs=[response_video, mask_output_video],
363
+ api_name="describe_video"
364
+ )
365
+
366
+ video_input.clear(
367
+ fn=clear_all,
368
+ outputs=[mask_output_video, mask_list_video, mask_raw_list_video, mask_ids, selected_frame, query_video, response_video]
369
+ )
370
+
371
+ clear_masks_btn_video.click(
372
+ fn=clear_masks,
373
+ outputs=[mask_output_video, mask_list_video, mask_raw_list_video, mask_ids]
374
+ )
375
+
376
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
377
+ sam_model = SamModel.from_pretrained("facebook/sam-vit-huge").to(device)
378
+ sam_processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
379
+ # sam_model = sam_processor = None
380
+ disable_torch_init()
381
+ model, processor = model_init(args_cli.model_path)
382
+ # model = processor = None
383
+
384
+ # demo.launch()
385
+ demo.launch(
386
+ share=False,
387
+ )
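Note: the point-prompted SAM call wrapped by apply_sam above can be exercised on its own. A minimal standalone sketch using the same transformers checkpoint (the frame path and the click coordinate are illustrative placeholders):

    import torch
    from PIL import Image
    from transformers import SamModel, SamProcessor

    device = "cuda" if torch.cuda.is_available() else "cpu"
    sam_model = SamModel.from_pretrained("facebook/sam-vit-huge").to(device)
    sam_processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")

    image = Image.open("frame.jpg").convert("RGB")   # placeholder frame
    input_points = [[[320, 240]]]                    # one illustrative (x, y) click
    inputs = sam_processor(image, input_points=input_points, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = sam_model(**inputs)
    masks = sam_processor.image_processor.post_process_masks(
        outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu()
    )[0][0]
    best_mask = masks[outputs.iou_scores[0, 0].argmax()].numpy()  # HxW boolean mask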
requirements.txt ADDED
@@ -0,0 +1,44 @@
+ --extra-index-url https://download.pytorch.org/whl/cu124
+ # basic dependencies
+ torch==2.4.0
+ torchvision==0.19.0
+ datasets==2.21.0
+ transformers==4.46.3
+ tokenizers==0.20.3
+ deepspeed==0.15.4
+ accelerate==1.0.1
+ peft==0.4.0
+ timm==1.0.3
+ numpy==1.24.4
+ # data processing
+ decord==0.6.0
+ imageio==2.34.0
+ imageio-ffmpeg==0.4.9
+ moviepy==1.0.3
+ opencv-python==4.6.0.66
+ pyarrow
+ pysubs2
+ ffmpeg-python
+ # misc
+ scikit-learn==1.2.2
+ huggingface_hub==0.23.4
+ sentencepiece==0.1.99
+ shortuuid
+ einops==0.6.1
+ einops-exts==0.0.4
+ bitsandbytes==0.43.3 # for cuda 124
+ pydantic>=2.0
+ markdown2[all]
+ gradio==3.50.0
+ gradio_client==0.6.1
+ httpx==0.24.1
+ requests
+ openai
+ uvicorn
+ fastapi
+ tensorboard
+ wandb
+ tabulate
+ hydra-core
+ pycocotools==2.0.10
+ https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.3/flash_attn-2.7.3+cu12torch2.4cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
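The pins above target the CUDA 12.4 PyTorch wheels plus a prebuilt flash-attention wheel for Python 3.10. A quick post-install sanity check, assuming the environment was installed exactly as pinned:

    import torch, transformers, gradio

    print(torch.__version__, torch.version.cuda)  # expect 2.4.0 built against CUDA 12.4
    print(torch.cuda.is_available())              # True once the driver matches the cu124 wheels
    print(transformers.__version__)               # expect 4.46.3
    print(gradio.__version__)                     # expect 3.50.0, per the pin above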
rynnec/__init__.py ADDED
@@ -0,0 +1,269 @@
1
+ import os
2
+ import copy
3
+ import math
4
+ import warnings
5
+ import shutil
6
+ from functools import partial
7
+
8
+ import torch
9
+ import numpy as np
10
+ from .model import load_pretrained_model
11
+ from .mm_utils import load_images, process_images, load_video, process_video, tokenizer_multimodal_token, get_model_name_from_path, KeywordsStoppingCriteria, DirectResize, sam_preprocess_batch
12
+ from .constants import NUM_FRAMES, DEFAULT_IMAGE_TOKEN, DEFAULT_VIDEO_TOKEN, MODAL_INDEX_MAP, STREAM_START_TOKEN, STREAM_END_TOKEN
13
+ from .model.rynnec_qwen2 import Videollama3Qwen2Processor
14
+
15
+ def disable_torch_init():
16
+ """
17
+ Disable the redundant torch default initialization to accelerate model creation.
18
+ """
19
+ import torch
20
+ setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
21
+ setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
22
+
23
+
24
+ def model_init(model_path=None, min_visual_tokens=None, max_visual_tokens=None, **kwargs):
25
+ model_path = "Alibaba-DAMO-Academy/RynnEC-2B" if model_path is None else model_path
26
+ model_name = get_model_name_from_path(model_path)
27
+ tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, None, model_name, **kwargs)
28
+
29
+ if max_visual_tokens is not None:
30
+ image_processor.max_tokens = max_visual_tokens
31
+ if min_visual_tokens is not None:
32
+ image_processor.min_tokens = min_visual_tokens
33
+
34
+ if tokenizer.pad_token is None and tokenizer.unk_token is not None:
35
+ tokenizer.pad_token = tokenizer.unk_token
36
+
37
+ processor = Videollama3Qwen2Processor(image_processor, tokenizer)
38
+
39
+ return model, processor
40
+
41
+
42
+ def mm_infer(images_or_videos, vlprocessor, instruct, model, tokenizer, modal='video', **kwargs):
43
+
44
+ mask_ids = kwargs.pop('mask_ids', None)
45
+ masks = kwargs.pop('masks', None)
46
+ if modal == 'image':
47
+ modal_token = DEFAULT_IMAGE_TOKEN
48
+ images = images_or_videos
49
+ timestamps = None
50
+ elif modal == 'video':
51
+ modal_token = DEFAULT_VIDEO_TOKEN
52
+ images, timestamps = images_or_videos
53
+ elif modal == 'text':
54
+ modal_token = ''
55
+ else:
56
+ raise ValueError(f"Unsupported modal: {modal}")
57
+
58
+
59
+ # 1. text preprocess (tag process & generate prompt).
60
+ if isinstance(instruct, str):
61
+ messages = [{'role': 'user', 'content': instruct}]
62
+ elif isinstance(instruct, list):
63
+ messages = copy.deepcopy(instruct)
64
+ else:
65
+ raise ValueError(f"Unsupported type of instruct: {type(instruct)}")
66
+
67
+ if all(modal_token not in message["content"] for message in messages):
68
+ warnings.warn("Multimodal tag not found in the conversation; adding it automatically at the beginning.")
69
+ messages[0]["content"] = modal_token + messages[0]["content"]
70
+
71
+ converted_messages = []
72
+ for message in messages:
73
+ chunks = message["content"].split(modal_token)
74
+ converted_messages.append({
75
+ "role": "user",
76
+ "content": []
77
+ })
78
+
79
+ for chunk_idx in range(1, 2 * len(chunks)):
80
+ if chunk_idx % 2 == 1:
81
+ chunk = chunks[chunk_idx // 2].strip()
82
+ converted_messages[-1]["content"].append({"type": "text", "text": chunk}) if chunk else None
83
+ else:
84
+ if modal == 'image':
85
+ converted_messages[-1]["content"].append({"type": "image"})
86
+ elif modal == 'video':
87
+ converted_messages[-1]["content"].append({"type": "video", "num_frames": len(images), "time": timestamps})
88
+
89
+ messages = converted_messages
90
+
91
+ system_message = []
92
+
93
+ image_downsampling = kwargs.get('image_downsampling', model.config.spatial_merge_size)
94
+ # TODO: attention mask?
95
+ messages = system_message + messages
96
+ data_dict = vlprocessor(
97
+ images=images,
98
+ text=messages,
99
+ merge_size=image_downsampling,
100
+ return_labels=True,
101
+ return_tensors="pt",
102
+ )
103
+
104
+ torch_dtype = model.config.torch_dtype if hasattr(model.config, "torch_dtype") else torch.float16
105
+
106
+ # images = [x.to(torch_dtype).cuda(non_blocking=True) for x in data_dict["images"]]
107
+ # grid_thws = [x.cuda(non_blocking=True) for x in data_dict["grid_thws"]]
108
+
109
+ # 3. generate response according to visual signals and prompts.
110
+ keywords = [tokenizer.eos_token]
111
+ stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, data_dict["input_ids"].unsqueeze(0))
112
+
113
+ do_sample = kwargs.get('do_sample', False)
114
+ temperature = kwargs.get('temperature', 0.2 if do_sample else 1.0)
115
+ top_p = kwargs.get('top_p', 0.9 if do_sample else 1.0)
116
+ top_k = kwargs.get('top_k', 20 if do_sample else 50)
117
+ max_new_tokens = kwargs.get('max_new_tokens', 2048)
118
+
119
+ data_dict["modals"] = [modal]
120
+ data_dict = {k: v.cuda() if isinstance(v, torch.Tensor) else v for k, v in data_dict.items()}
121
+ if "pixel_values" in data_dict:
122
+ data_dict["modals"] = data_dict["modals"] * len(data_dict["grid_sizes"])
123
+ data_dict["pixel_values"] = data_dict["pixel_values"].to(torch.bfloat16)
124
+
125
+ with torch.inference_mode():
126
+ output_ids = model.generate(
127
+ input_ids=data_dict["input_ids"].unsqueeze(0).cuda(),
128
+ pixel_values=data_dict["pixel_values"],
129
+ grid_sizes=data_dict["grid_sizes"],
130
+ merge_sizes=data_dict["merge_sizes"],
131
+ modals=data_dict["modals"],
132
+ do_sample=do_sample,
133
+ temperature=temperature,
134
+ max_new_tokens=max_new_tokens,
135
+ top_p=top_p,
136
+ top_k=top_k,
137
+ use_cache=True,
138
+ stopping_criteria=[stopping_criteria],
139
+ pad_token_id=tokenizer.eos_token_id,
140
+ masks=[masks],
141
+ mask_ids=mask_ids
142
+ )
143
+
144
+ outputs = tokenizer.decode(output_ids[0], skip_special_tokens=True)
145
+
146
+ return outputs
147
+
148
+ def mm_infer_segmentation(images_or_videos, vlprocessor, instruct, model, tokenizer, modal='video', seg_start_idx=0, **kwargs):
149
+
150
+ image2maskids = kwargs.get('image2maskids', [])
151
+ img_size=1024
152
+ sam_transform = DirectResize(img_size)
153
+
154
+
155
+ if modal == 'image':
156
+ modal_token = DEFAULT_IMAGE_TOKEN
157
+ images = images_or_videos
158
+ timestamps = None
159
+ elif modal == 'video':
160
+ modal_token = DEFAULT_VIDEO_TOKEN
161
+ images, timestamps = images_or_videos
162
+ elif modal == 'text':
163
+ modal_token = ''
164
+ else:
165
+ raise ValueError(f"Unsupported modal: {modal}")
166
+
167
+
168
+ sam_images = []
169
+ sam_size = None
170
+ if len(images)>0:
171
+ for image in images:
172
+ sam_image = sam_transform.apply_image(np.array(image))
173
+ sam_images.append(sam_image)
174
+ if sam_size is None:
175
+ sam_size = sam_image.shape[:2]
176
+ sam_images = np.array(sam_images)
177
+ sam_images = torch.from_numpy(sam_images).permute(0, 3, 1, 2).contiguous()
178
+ sam_images = sam_preprocess_batch(sam_images)
179
+
180
+
181
+ # 1. text preprocess (tag process & generate prompt).
182
+ if isinstance(instruct, str):
183
+ messages = [{'role': 'user', 'content': instruct}]
184
+ elif isinstance(instruct, list):
185
+ messages = copy.deepcopy(instruct)
186
+ else:
187
+ raise ValueError(f"Unsupported type of instruct: {type(instruct)}")
188
+
189
+ if all(modal_token not in message["content"] for message in messages):
190
+ warnings.warn("Multimodal tag not found in the conversation; adding it automatically at the beginning.")
191
+ messages[0]["content"] = modal_token + messages[0]["content"]
192
+
193
+ converted_messages = []
194
+ for message in messages:
195
+ chunks = message["content"].split(modal_token)
196
+ converted_messages.append({
197
+ "role": "user",
198
+ "content": []
199
+ })
200
+
201
+ for chunk_idx in range(1, 2 * len(chunks)):
202
+ if chunk_idx % 2 == 1:
203
+ chunk = chunks[chunk_idx // 2].strip()
204
+ converted_messages[-1]["content"].append({"type": "text", "text": chunk}) if chunk else None
205
+ else:
206
+ if modal == 'image':
207
+ converted_messages[-1]["content"].append({"type": "image"})
208
+ elif modal == 'video':
209
+ converted_messages[-1]["content"].append({"type": "video", "num_frames": len(images), "time": timestamps})
210
+
211
+ messages = converted_messages
212
+
213
+ system_message = []
214
+
215
+ image_downsampling = kwargs.get('image_downsampling', model.config.spatial_merge_size)
216
+ # TODO: attention mask?
217
+ messages = system_message + messages
218
+ data_dict = vlprocessor(
219
+ images=images,
220
+ text=messages,
221
+ merge_size=image_downsampling,
222
+ return_labels=True,
223
+ return_tensors="pt",
224
+ )
225
+
226
+ torch_dtype = model.config.torch_dtype if hasattr(model.config, "torch_dtype") else torch.float16
227
+
228
+ keywords = [tokenizer.eos_token]
229
+ stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, data_dict["input_ids"].unsqueeze(0))
230
+
231
+ do_sample = kwargs.get('do_sample', False)
232
+ temperature = kwargs.get('temperature', 0.2 if do_sample else 1.0)
233
+ top_p = kwargs.get('top_p', 0.9 if do_sample else 1.0)
234
+ top_k = kwargs.get('top_k', 20 if do_sample else 50)
235
+ max_new_tokens = kwargs.get('max_new_tokens', 2048)
236
+
237
+ torch_dtype = model.config.torch_dtype if hasattr(model.config, "torch_dtype") else torch.float16
238
+
239
+ data_dict["modals"] = [modal]
240
+ data_dict = {k: v.cuda() if isinstance(v, torch.Tensor) else v for k, v in data_dict.items()}
241
+ if "pixel_values" in data_dict:
242
+ data_dict["modals"] = data_dict["modals"] * len(data_dict["grid_sizes"])
243
+ data_dict["pixel_values"] = data_dict["pixel_values"].to(torch.bfloat16)
244
+
245
+ with torch.inference_mode():
246
+ output_ids, pred_masks = model.inference(
247
+ input_ids=data_dict["input_ids"].unsqueeze(0).cuda(),
248
+ pixel_values=data_dict["pixel_values"],
249
+ grid_sizes=data_dict["grid_sizes"],
250
+ merge_sizes=data_dict["merge_sizes"],
251
+ modals=data_dict["modals"],
252
+ sam_images=[sam_images],
253
+ sam_size=[sam_size],
254
+ image2maskids=[image2maskids],
255
+ do_sample=do_sample,
256
+ temperature=temperature,
257
+ max_new_tokens=max_new_tokens,
258
+ top_p=top_p,
259
+ top_k=top_k,
260
+ use_cache=True,
261
+ stopping_criteria=[stopping_criteria],
262
+ pad_token_id=tokenizer.eos_token_id,
263
+ seg_start_idx=seg_start_idx
264
+ )
265
+
266
+ outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
267
+ pred_masks_sigmoid = pred_masks.sigmoid()>0.5
268
+
269
+ return outputs, pred_masks_sigmoid
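For reference, model_init and the two mm_infer* entry points are used as in app.py. A condensed sketch of the segmentation path (the video path and the prompt are illustrative, and a CUDA GPU is assumed):

    from PIL import Image
    from rynnec import disable_torch_init, model_init, mm_infer_segmentation
    from rynnec.mm_utils import load_video

    disable_torch_init()
    model, processor = model_init("Alibaba-DAMO-Academy/RynnEC-2B")

    frames, timestamps = load_video("./demo/videos/1.mp4", fps=1, max_frames=128)
    frames = [Image.fromarray(f.transpose(1, 2, 0)) for f in frames]  # CHW uint8 -> PIL

    text, masks = mm_infer_segmentation(
        (frames, timestamps), processor, "<video>\nPlease segment the cup on the table.",
        model=model, tokenizer=processor.tokenizer, do_sample=False, modal="video",
    )
    print(text, masks.shape)  # masks: thresholded per-frame segmentation output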
rynnec/constants.py ADDED
@@ -0,0 +1,47 @@
+ CONTROLLER_HEART_BEAT_EXPIRATION = 30
+ WORKER_HEART_BEAT_INTERVAL = 15
+
+ LOGDIR = "."
+
+ # Model Constants
+ IGNORE_INDEX = -100
+
+ # Image arguments
+ IMAGE_TOKEN_INDEX = -200
+ DEFAULT_IMAGE_TOKEN = "<image>"
+ DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
+ DEFAULT_IM_START_TOKEN = "<im_start>"
+ DEFAULT_IM_END_TOKEN = "<im_end>"
+ IMAGE_PLACEHOLDER = "<image-placeholder>"
+
+ # Video arguments
+ VIDEO_TOKEN_INDEX = -201
+ DEFAULT_VIDEO_TOKEN = "<video>"
+ NUM_FRAMES = 128
+ MAX_FRAMES = 768
+ NUM_FRAMES_PER_SECOND = 1
+
+ # Region arguments
+ REGION_TOKEN = "<REGION>"
+ REGION_TOKEN_REPLACE = "<region>"
+ SEG_TOKEN = "[SEG]"
+
+ # Audio arguments
+ AUDIO_TOKEN_INDEX = -202
+ DEFAULT_AUDIO_TOKEN = "<audio>"
+
+ # Stream arguments
+ STREAM_START_TOKEN = "<|stream_start|>"
+ STREAM_END_TOKEN = "<|stream_end|>"
+ STREAM_MAX_FRAMES = 400
+ STREAM_FPS = 2
+ STREAM_IMAGE_SIZE = 224
+ STREAM_DOWNSAMPLING = 4
+
+ MODAL_INDEX_MAP = {
+     "<image>": -200,
+     "<video>": -201,
+     "<audio>": -202,
+ }
+
+ subimage_token_num = 196
rynnec/mm_utils.py ADDED
@@ -0,0 +1,733 @@
1
+ # Adapted from: https://github.com/DAMO-NLP-SG/VideoLLaMA3.
2
+ # Below is the original copyright:
3
+ # Copyright 2025 The VideoLLaMA3 team, Alibaba Group
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import ast
17
+ import os
18
+ import re
19
+ import math
20
+ import base64
21
+ import traceback
22
+ from io import BytesIO
23
+ from typing import Optional
24
+
25
+ import torch
26
+ import torchvision.transforms.functional as VF
27
+ import torch.nn.functional as F
28
+ import numpy as np
29
+ from transformers import StoppingCriteria
30
+
31
+ import cv2
32
+ import imageio
33
+ import ffmpeg
34
+ from PIL import Image
35
+ from decord import VideoReader, cpu
36
+
37
+ from .constants import NUM_FRAMES, MAX_FRAMES, NUM_FRAMES_PER_SECOND, MODAL_INDEX_MAP, DEFAULT_IMAGE_TOKEN
38
+ from pycocotools import mask as maskUtils
39
+
40
+ from torchvision.transforms.functional import resize, to_pil_image # type: ignore
41
+
42
+
43
+ class DirectResize:
44
+ def __init__(self, target_length: int) -> None:
45
+ self.target_length = target_length
46
+
47
+ def apply_image(self, image: np.ndarray) -> np.ndarray:
48
+ """
49
+ Expects a numpy array with shape HxWxC in uint8 format.
50
+ """
51
+ img = to_pil_image(image, mode='RGB')
52
+ return np.array(img.resize((self.target_length, self.target_length)))
53
+
54
+ def sam_preprocess_batch(x: torch.Tensor) -> torch.Tensor:
55
+ """
56
+ Normalize pixel values and pad to square input for a batch of images.
57
+
58
+ Args:
59
+ x (torch.Tensor): A batch tensor of shape [N, C, H, W].
60
+
61
+ Returns:
62
+ torch.Tensor: A batch tensor with normalized and padded images
63
+ (shape: [N, C, 1024, 1024]).
64
+ """
65
+ pixel_mean = torch.Tensor([123.675, 116.28, 103.53]).view(1, -1, 1, 1)
66
+ pixel_std = torch.Tensor([58.395, 57.12, 57.375]).view(1, -1, 1, 1)
67
+ img_size = 1024
68
+
69
+ # Normalize colors
70
+ x = (x - pixel_mean) / pixel_std
71
+
72
+ # Pad
73
+ h, w = x.shape[-2:]
74
+ padh = img_size - h
75
+ padw = img_size - w
76
+ x = F.pad(x, (0, padw, 0, padh))
77
+ return x
78
+
79
+
80
+ def sam_preprocess(x: torch.Tensor) -> torch.Tensor:
81
+ """Normalize pixel values and pad to a square input."""
82
+ pixel_mean = torch.Tensor([123.675, 116.28, 103.53]).view(-1, 1, 1)
83
+ pixel_std = torch.Tensor([58.395, 57.12, 57.375]).view(-1, 1, 1)
84
+ img_size = 1024
85
+
86
+ # Normalize colors
87
+ x = (x - pixel_mean) / pixel_std
88
+
89
+ # Pad
90
+ h, w = x.shape[-2:]
91
+ padh = img_size - h
92
+ padw = img_size - w
93
+ x = F.pad(x, (0, padw, 0, padh))
94
+ return x
95
+
96
+
97
+ def reshape_images_to_raw_grid(mm_features_raw, grid_thws):
98
+ start_idx=0
99
+ reshaped_features = []
100
+ # for thw_group in grid_thws:
101
+ for tensor_thw in grid_thws:
102
+ # for tensor_thw in thw_group:
103
+ t, H, W = tensor_thw.squeeze().tolist()
104
+ num_elements = H * W
105
+ for i in range(t):
106
+ split_tensor = mm_features_raw[start_idx:start_idx + num_elements].view(H, W, -1)
107
+ reshaped_features.append(split_tensor)
108
+
109
+ start_idx += num_elements
110
+
111
+ assert len(mm_features_raw)==start_idx
112
+ return reshaped_features
113
+
114
+ def annToMask(mask_ann, h=None, w=None):
115
+ if isinstance(mask_ann, list):
116
+ rles = maskUtils.frPyObjects(mask_ann, h, w)
117
+ rle = maskUtils.merge(rles)
118
+ elif isinstance(mask_ann['counts'], list):
119
+ # uncompressed RLE
120
+ rle = maskUtils.frPyObjects(mask_ann, h, w)
121
+ else:
122
+ # rle
123
+ rle = mask_ann
124
+ mask = maskUtils.decode(rle)
125
+ return mask
126
+
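+ # Illustrative usage of annToMask (values made up): a COCO-style polygon annotation
+ # [[10, 10, 50, 10, 50, 50, 10, 50]] with h=64, w=64 decodes to a 64x64 uint8 mask
+ # with 1s inside the square; compressed and uncompressed RLE dicts are decoded directly.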
127
+ def chunk_list(input_list, chunk_size):
128
+ return [input_list[i:i + chunk_size] for i in range(0, len(input_list), chunk_size)]
129
+
130
+
131
+ def load_image_from_base64(image):
132
+ return Image.open(BytesIO(base64.b64decode(image)))
133
+
134
+
135
+ def expand2square(pil_img, background_color):
136
+ width, height = pil_img.size
137
+ if width == height:
138
+ return pil_img
139
+ elif width > height:
140
+ result = Image.new(pil_img.mode, (width, width), background_color)
141
+ result.paste(pil_img, (0, (width - height) // 2))
142
+ return result
143
+ else:
144
+ result = Image.new(pil_img.mode, (height, height), background_color)
145
+ result.paste(pil_img, ((height - width) // 2, 0))
146
+ return result
147
+
148
+
149
+ def grid_divide(image, cell_size):
150
+ """
151
+ Divides an image into a grid of cells of a specified size.
152
+
153
+ Args:
154
+ image (PIL.Image.Image): The input image.
155
+ cell_size (int): The size of each cell.
156
+
157
+ Returns:
158
+ list: A list of PIL.Image.Image objects representing the patches.
159
+ """
160
+ grid = []
161
+ width, height = image.size
162
+ for i in range(0, height, cell_size):
163
+ row = []
164
+ for j in range(0, width, cell_size):
165
+ box = (j, i, j + cell_size, i + cell_size)
166
+ row.append(image.crop(box))
167
+ grid.append(row)
168
+
169
+ return grid
170
+
171
+
172
+ def load_images(image_path):
173
+ if isinstance(image_path, str) and os.path.isfile(image_path):
174
+ # images = [cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)]
175
+ images = [Image.open(image_path).convert('RGB')]
176
+ elif isinstance(image_path, str) and os.path.isdir(image_path):
177
+ # images = [cv2.cvtColor(cv2.imread(os.path.join(image_path, f)), cv2.COLOR_BGR2RGB) for f in sorted(os.listdir(image_path))]
178
+ images = [Image.open(os.path.join(image_path, f)).convert('RGB') for f in sorted(os.listdir(image_path))]
179
+ elif isinstance(image_path, list) and isinstance(image_path[0], str):
180
+ # images = [cv2.cvtColor(cv2.imread(f), cv2.COLOR_BGR2RGB) for f in image_path]
181
+ images = [Image.open(f).convert('RGB') for f in image_path]
182
+ elif isinstance(image_path, list) and isinstance(image_path[0], Image.Image):
183
+ images = image_path
184
+ elif isinstance(image_path, Image.Image):
185
+ images = [image_path]
186
+ else:
187
+ raise ValueError(f"Unsupported image path type: {image_path}")
188
+
189
+ return images
190
+
191
+
192
+ def process_pad_image(image, padding_value=(0, 0, 0)):
193
+ image = expand2square(image, padding_value)
194
+
195
+ return [image]
196
+
197
+
198
+ def find_closest_aspect_ratio(src_ratio, tgt_ratios, ori_size, tgt_size):
199
+ best_ratio_diff = float('inf')
200
+ best_ratio = (1, 1)
201
+ area = ori_size[0] * ori_size[1]
202
+ for ratio in tgt_ratios:
203
+ tgt_ratio = ratio[0] / ratio[1]
204
+ ratio_diff = abs(src_ratio - tgt_ratio)
205
+ if ratio_diff < best_ratio_diff:
206
+ best_ratio_diff = ratio_diff
207
+ best_ratio = ratio
208
+ elif ratio_diff == best_ratio_diff:
209
+ if area > 0.5 * tgt_size[0] * tgt_size[1] * ratio[0] * ratio[1]:
210
+ best_ratio = ratio
211
+
212
+ return best_ratio
213
+
214
+
215
+ def process_dynamic_image(image, image_size=384, use_thumbnail=True):
216
+ # Grid Params:
217
+ min_num = 1
218
+ max_num = 12
219
+
220
+ if isinstance(image_size, int):
221
+ image_size = (image_size, image_size)
222
+
223
+ ori_size = image.size
224
+ aspect_ratio = ori_size[0] / ori_size[1]
225
+
226
+ # calculate the existing image aspect ratio
227
+ tgt_ratios = []
228
+ for n in range(min_num, max_num + 1):
229
+ tgt_ratios.extend([(i, j) for i in range(1, n + 1) for j in range(1, n + 1) if i * j <= max_num and i * j >= min_num])
230
+ tgt_ratios = set(tgt_ratios)
231
+ tgt_ratios = sorted(tgt_ratios, key=lambda x: x[0] * x[1])
232
+
233
+ # find the closest aspect ratio to the target
234
+ tgt_ratio = find_closest_aspect_ratio(aspect_ratio, tgt_ratios, ori_size, image_size)
235
+
236
+ # resize the image to the target size
237
+ tgt_width = image_size[0] * tgt_ratio[0]
238
+ tgt_height = image_size[1] * tgt_ratio[1]
239
+ resized_img = image.resize((tgt_width, tgt_height))
240
+
241
+ # NOTE: internvl2 style split the image into one column grids
242
+ # num_grids = tgt_ratio[0] * tgt_ratio[1]
243
+ # grid_images = []
244
+ # for i in range(num_grids):
245
+ # box = (
246
+ # (i % tgt_ratio[0]) * image_size[0],
247
+ # (i // tgt_ratio[0]) * image_size[1],
248
+ # (i % tgt_ratio[0] + 1) * image_size[0],
249
+ # (i // tgt_ratio[0] + 1) * image_size[1],
250
+ # )
251
+ # # crop out the grid image
252
+ # grid_images.append(resized_img.crop(box))
253
+ # assert len(grid_images) == num_grids
254
+ # grid_images = [grid_images]
255
+
256
+ # NOTE: eager implementation
257
+ # num_grids = tgt_ratio[0] * tgt_ratio[1]
258
+ # sub_grid_images = []
259
+ # tmp_grid_images = []
260
+ # for i in range(num_grids):
261
+ # box = (
262
+ # (i % tgt_ratio[0]) * image_size[0],
263
+ # (i // tgt_ratio[0]) * image_size[1],
264
+ # (i % tgt_ratio[0] + 1) * image_size[0],
265
+ # (i // tgt_ratio[0] + 1) * image_size[1],
266
+ # )
267
+ # tmp_grid_images.append(resized_img.crop(box))
268
+
269
+ # if (i + 1) % tgt_ratio[0] == 0:
270
+ # sub_grid_images.append(tmp_grid_images)
271
+ # tmp_grid_images = []
272
+
273
+ image_grid = grid_divide(resized_img, image_size[0])
274
+
275
+ if use_thumbnail:
276
+ thumbnail_img = image.resize((image_size[0], image_size[1]))
277
+ image_grid = [[thumbnail_img]] + image_grid
278
+
279
+ return image_grid
280
+
281
+
282
+ def process_highres_image(image, image_size=384, use_thumbnail=True, padding_value=(0, 0, 0)):
283
+ # Grid Params:
284
+ grid_width = [1, 2, 3]
285
+ grid_width_real = [x * image_size for x in grid_width]
286
+
287
+ longest_side = max(image.size)
288
+ fit_grid_width_real = [x for x in grid_width_real if x >= longest_side]
289
+ if len(fit_grid_width_real) == 0:
290
+ select_size = max(grid_width_real)
291
+ else:
292
+ select_size = min(fit_grid_width_real)
293
+
294
+ image_padded = expand2square(image, padding_value)
295
+ image_padded = image_padded.resize((select_size, select_size))
296
+ image_grid = grid_divide(image_padded, image_size)
297
+
298
+ if use_thumbnail:
299
+ thumbnail_img = image.resize((image_size, image_size))
300
+ image_grid = [[thumbnail_img]] + image_grid
301
+
302
+ return image_grid
303
+
304
+
305
+ def select_best_resolution(original_size, possible_resolutions):
306
+ """
307
+ Selects the best resolution from a list of possible resolutions based on the original size.
308
+
309
+ Args:
310
+ original_size (tuple): The original size of the image in the format (width, height).
311
+ possible_resolutions (list): A list of possible resolutions in the format [(width1, height1), (width2, height2), ...].
312
+
313
+ Returns:
314
+ tuple: The best fit resolution in the format (width, height).
315
+ """
316
+ original_width, original_height = original_size
317
+ best_fit = None
318
+ max_effective_resolution = 0
319
+ min_wasted_resolution = float('inf')
320
+
321
+ for width, height in possible_resolutions:
322
+ scale = min(width / original_width, height / original_height)
323
+ downscaled_width, downscaled_height = int(original_width * scale), int(original_height * scale)
324
+ effective_resolution = min(downscaled_width * downscaled_height, original_width * original_height)
325
+ wasted_resolution = (width * height) - effective_resolution
326
+
327
+ if effective_resolution > max_effective_resolution or (effective_resolution == max_effective_resolution and wasted_resolution < min_wasted_resolution):
328
+ max_effective_resolution = effective_resolution
329
+ min_wasted_resolution = wasted_resolution
330
+ best_fit = (width, height)
331
+
332
+ return best_fit
333
+
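+ # Illustrative example: for an original size of (1000, 750) and candidates
+ # [(384, 384), (768, 384), (768, 768)], select_best_resolution returns (768, 768),
+ # the candidate that preserves the largest effective (downscaled) resolution.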
334
+
335
+ def process_anyres_image(image, image_size=384, use_thumbnail=True, padding_value=(0, 0, 0)):
336
+ """
337
+ Process an image with variable resolutions.
338
+
339
+ Args:
340
+ image (PIL.Image.Image): The input image to be processed.
341
+ processor: The image processor object.
342
+
343
+ Returns:
344
+ torch.Tensor: A tensor containing the processed image patches.
345
+ """
346
+ # Grid Params:
347
+ possible_grids = [(1, 1), (1, 2), (1, 3), (2, 1), (2, 2), (2, 3)]
348
+ possible_resolutions = [(x * image_size, y * image_size) for x, y in possible_grids]
349
+
350
+ best_resolution = select_best_resolution(image.size, possible_resolutions)
351
+
352
+ # resize and padding image
353
+ nw, nh = best_resolution
354
+ ow, oh = image.size
355
+
356
+ scale_factor = min(nw / ow, nh / oh)
357
+ new_size = (int(ow * scale_factor), int(oh * scale_factor))
358
+
359
+ image_padded = Image.new("RGB", (nw, nh), padding_value)
360
+ image_padded.paste(image.resize(new_size), ((nw - new_size[0]) // 2, (nh - new_size[1]) // 2))
361
+
362
+ image_grid = grid_divide(image_padded, image_size)
363
+
364
+ if use_thumbnail:
365
+ thumbnail_img = image.resize((image_size, image_size))
366
+ image_grid = [[thumbnail_img]] + image_grid
367
+
368
+ return image_grid
369
+
370
+
371
+ def process_adares_image(image, image_size=384, use_thumbnail=True):
372
+ # Grid Params:
373
+ min_num = 1
374
+ max_num = 12
375
+
376
+ if isinstance(image_size, int):
377
+ image_size = (image_size, image_size)
378
+
379
+ ori_size = image.size
380
+ aspect_ratio = ori_size[0] / ori_size[1]
381
+
382
+ # calculate the existing image aspect ratio
383
+ tgt_ratios = []
384
+ for n in range(min_num, max_num + 1):
385
+ tgt_ratios.extend([(i, j) for i in range(1, n + 1) for j in range(1, n + 1) if i * j <= max_num and i * j >= min_num])
386
+ tgt_ratios = set(tgt_ratios)
387
+ possible_resolutions = [(x * image_size[0], y * image_size[1]) for x, y in tgt_ratios]
388
+
389
+ # find the most possible resolution
390
+ best_resolution = select_best_resolution(ori_size, possible_resolutions)
391
+
392
+ # resize the image to the target size
393
+ resized_img = image.resize((best_resolution[0], best_resolution[1]))
394
+
395
+ image_grid = grid_divide(resized_img, image_size[0])
396
+
397
+ if use_thumbnail:
398
+ thumbnail_img = image.resize((image_size[0], image_size[1]))
399
+ image_grid = [[thumbnail_img]] + image_grid
400
+
401
+ return image_grid
402
+
403
+
404
+ def process_images(image_path, processor, aspect_ratio='pad', image_size=384, use_thumbnail=True):
405
+ images = load_images(image_path)
406
+
407
+ padding_value = tuple(int(x*255) for x in processor.image_mean)
408
+
409
+ image_grids = []
410
+ for image in images:
411
+ if aspect_ratio == 'pad':
412
+ image_grid = process_pad_image(image, padding_value=padding_value)
413
+ elif aspect_ratio == 'dynamic':
414
+ image_grid = process_dynamic_image(image, image_size=image_size, use_thumbnail=use_thumbnail)
415
+ elif aspect_ratio == 'highres':
416
+ image_grid = process_highres_image(image, image_size=image_size, use_thumbnail=use_thumbnail, padding_value=padding_value)
417
+ elif aspect_ratio == 'anyres':
418
+ image_grid = process_anyres_image(image, image_size=image_size, use_thumbnail=use_thumbnail, padding_value=padding_value)
419
+ elif aspect_ratio == 'adares':
420
+ image_grid = process_adares_image(image, image_size=image_size, use_thumbnail=use_thumbnail)
421
+ else:
422
+ image_grid = [image]
423
+
424
+ image_grid = [processor.preprocess(image_row, return_tensors='pt', num_images=len(images)) for image_row in image_grid]
425
+ image_grids.append(image_grid)
426
+
427
+ return image_grids
428
+
429
+
430
+ def frame_sample(duration, mode='uniform', num_frames=None, vid_fps=None, fps=None, must_sample_frames=None):
431
+ mask_ids = []
432
+ if mode == 'uniform':
433
+ assert num_frames is not None, "Number of frames must be provided for uniform sampling."
434
+ if duration <= num_frames:
435
+ video_ids = np.arange(duration).astype(int)
436
+ video_ids_list = video_ids.tolist()
437
+ for msf in must_sample_frames:
438
+ if msf not in video_ids_list:
439
+ video_ids_list.append(msf)
440
+ video_ids_list.sort()
441
+ for msf in must_sample_frames:
442
+ mask_ids.append(video_ids_list.index(msf))
443
+ return np.array(video_ids_list), mask_ids
444
+ video_ids = np.linspace(0, duration-1, num_frames, dtype=int)
445
+ video_ids_list = video_ids.tolist()
446
+ if must_sample_frames is not None:
447
+ for msf in must_sample_frames:
448
+ if msf not in video_ids_list:
449
+ video_ids_list.append(msf)
450
+ video_ids_list.sort()
451
+ for msf in must_sample_frames:
452
+ mask_ids.append(video_ids_list.index(msf))
453
+ return np.array(video_ids_list), mask_ids
454
+ elif mode == 'fps':
455
+ assert vid_fps is not None, "FPS must be provided for FPS sampling."
456
+ fps = fps if fps is not None else NUM_FRAMES_PER_SECOND
457
+ segment_len = min(vid_fps // fps, duration)
458
+ video_ids = np.arange(segment_len // 2, duration, segment_len, dtype=int)
459
+ video_ids_list = video_ids.tolist()
460
+ if must_sample_frames is not None:
461
+ for msf in must_sample_frames:
462
+ if msf not in video_ids_list:
463
+ video_ids_list.append(msf)
464
+ video_ids_list.sort()
465
+ for msf in must_sample_frames:
466
+ mask_ids.append(video_ids_list.index(msf))
467
+ return np.array(video_ids_list), mask_ids
468
+
469
+ else:
470
+ raise ValueError(f'Unsupported frame sampling mode: {mode}')
471
+
472
+
473
+ def load_video_from_ids(video_path, s=None, e=None, fps=None, max_frames=None, temporal_factor=1, must_sample_frames=None):
474
+ if s is not None and e is not None:
475
+ s = s if s >= 0. else 0.
476
+ e = e if e >= 0. else 0.
477
+ if s > e:
478
+ s, e = e, s
479
+ elif s == e:
480
+ e = s + 1
481
+
482
+ # 1. Loading Video
483
+ if os.path.isdir(video_path):
484
+ frame_files = sorted(os.listdir(video_path))
485
+
486
+ vid_fps = 3
487
+ num_frames_of_video = len(frame_files)
488
+ elif video_path.endswith('.gif'):
489
+ gif_reader = imageio.get_reader(video_path)
490
+
491
+ vid_fps = 25
492
+ num_frames_of_video = len(gif_reader)
493
+ else:
494
+ vreader = VideoReader(video_path, ctx=cpu(0), num_threads=2)
495
+ # vreader = VideoReader(video_path, ctx=cpu(0), num_threads=1)
496
+
497
+ vid_fps = vreader.get_avg_fps()
498
+ num_frames_of_video = len(vreader)
499
+
500
+ # 2. Determine frame range & Calculate frame indices
501
+ f_start = 0 if s is None else max(int(s * vid_fps) - 1, 0)
502
+ f_end = num_frames_of_video - 1 if e is None else min(int(e * vid_fps) - 1, num_frames_of_video - 1)
503
+ frame_indices = list(range(f_start, f_end + 1))
504
+
505
+ duration = len(frame_indices)
506
+ # 3. Sampling frame indices
507
+ max_frames = max_frames if max_frames is not None else MAX_FRAMES
508
+ if fps is not None and duration / vid_fps < max_frames:
509
+ sampled_ids, mask_ids = frame_sample(duration, mode='fps', vid_fps=vid_fps, fps=fps, must_sample_frames=must_sample_frames)
510
+ sampled_frame_indices = [frame_indices[i] for i in sampled_ids]
511
+ else:
512
+ sampled_ids, mask_ids = frame_sample(duration, mode='uniform', num_frames=max_frames, must_sample_frames=must_sample_frames)
513
+ sampled_frame_indices = [frame_indices[i] for i in sampled_ids]
514
+
515
+ # 4. Acquire frame data
516
+ if os.path.isdir(video_path):
517
+ frames = [cv2.cvtColor(cv2.imread(os.path.join(video_path, frame_files[frame_idx])), cv2.COLOR_BGR2RGB) for frame_idx in sampled_frame_indices]
518
+ elif video_path.endswith('.gif'):
519
+ frames = [cv2.cvtColor(frame, cv2.COLOR_RGBA2RGB) for idx, frame in enumerate(gif_reader) if idx in sampled_frame_indices]
520
+ else:
521
+ frames = vreader.get_batch(sampled_frame_indices).asnumpy()
522
+
523
+ # frames = frames.transpose(0, 3, 1, 2)
524
+ timestamps = [x / vid_fps for x in sampled_frame_indices]
525
+
526
+ if temporal_factor > 1:
527
+ pad_length = temporal_factor - len(frames) % temporal_factor
528
+ frames = np.concatenate([frames, frames[-1:].repeat(pad_length, axis=0)])
529
+ [timestamps.append(timestamps[-1] + 1 / fps) for _ in range(pad_length)]
530
+
531
+ # NOTE: pad the video with black frames
532
+ # while num_frames is not None and len(video_data) < num_frames:
533
+ # video_data.append(Image.fromarray(np.zeros((*video_data[-1].size, 3), dtype=np.uint8)))
534
+
535
+ return frames, timestamps, mask_ids
536
+
537
+
538
+ def load_video(
539
+ video_path: str,
540
+ start_time: Optional[float] = None,
541
+ end_time: Optional[float] = None,
542
+ fps: Optional[float] = None,
543
+ max_frames: Optional[float] = None,
544
+ size: Optional[int] = None,
545
+ size_divisible: int = 1,
546
+ precise_time: bool = False,
547
+ verbose: bool = False,
548
+ temporal_factor: int = 1
549
+ ):
550
+ """
551
+ Load and process a video file and return the frames and the timestamps of each frame.
552
+
553
+ Args:
554
+ video_path (str): Path to the video file.
555
+ start_time (float, optional): Start time in seconds. Defaults to None.
556
+ end_time (float, optional): End time in seconds. Defaults to None.
557
+ fps (float, optional): Frames per second. Defaults to None.
558
+ max_frames (float, optional): Maximum number of frames to sample. Defaults to None.
559
+ size (int, optional): Size of the shortest side. Defaults to None.
560
+ size_divisible (int, optional): Size divisible by this number. Defaults to 1.
561
+ precise_time (bool, optional): Whether to use precise time. Defaults to False.
562
+ verbose (bool, optional): Print ffmpeg output. Defaults to False.
563
+
564
+ Returns:
565
+ frames (List[PIL.Image]): List of frames.
566
+ timestamps (List[float]): List of timestamps.
567
+ """
568
+ if start_time is not None and end_time is not None and end_time - start_time < 1:
569
+ return load_video_from_ids(video_path, start_time, end_time, fps=fps, max_frames=max_frames)
570
+ if os.path.isdir(video_path):
571
+ return load_video_from_ids(video_path, start_time, end_time, fps=fps, max_frames=max_frames)
572
+ if video_path.endswith('.gif'):
573
+ return load_video_from_ids(video_path, start_time, end_time, fps=fps, max_frames=max_frames)
574
+ probe = ffmpeg.probe(video_path)
575
+ duration = float(probe['format']['duration'])
576
+ video_stream = next((stream for stream in probe['streams'] if stream['codec_type'] == 'video'), None)
577
+ w, h = int(video_stream['width']), int(video_stream['height'])
578
+
579
+ kwargs, input_kwargs, output_kwargs = {}, {}, {}
580
+ do_trim = start_time is not None or end_time is not None
581
+ if start_time is not None:
582
+ new_start_time = max(float(video_stream['start_time']), start_time)
583
+ duration -= new_start_time - start_time
584
+ start_time = new_start_time
585
+ else:
586
+ start_time = float(video_stream['start_time'])
587
+ if end_time is not None:
588
+ duration = min(duration, end_time - start_time)
589
+ else:
590
+ duration = duration
591
+ if do_trim:
592
+ kwargs = {'ss': start_time, 't': duration}
593
+ if precise_time:
594
+ output_kwargs.update(kwargs)
595
+ else:
596
+ input_kwargs.update(kwargs)
597
+
598
+ if size is not None:
599
+ scale_factor = size / min(w, h)
600
+ new_w, new_h = round(w * scale_factor), round(h * scale_factor)
601
+ else:
602
+ new_w, new_h = w, h
603
+ new_w = new_w // size_divisible * size_divisible
604
+ new_h = new_h // size_divisible * size_divisible
605
+
606
+ # NOTE: It may result in unexpected number of frames in ffmpeg
607
+ # if calculate the fps directly according to max_frames
608
+ # NOTE: the below lines may hurt the performance
609
+ # if max_frames is not None and (fps is None or duration * fps > 2 * max_frames):
610
+ # fps = max_frames / duration * 2
611
+
612
+ stream = ffmpeg.input(video_path, **input_kwargs)
613
+ if fps is not None:
614
+ stream = ffmpeg.filter(stream, "fps", fps=fps, round="down")
615
+ if new_w != w or new_h != h:
616
+ stream = ffmpeg.filter(stream, 'scale', new_w, new_h)
617
+ stream = ffmpeg.output(stream, "pipe:", format="rawvideo", pix_fmt="rgb24", **output_kwargs)
618
+ out, _ = ffmpeg.run(stream, capture_stdout=True, quiet=not verbose)
619
+
620
+ frames = np.frombuffer(out, np.uint8).reshape([-1, new_h, new_w, 3]).transpose([0, 3, 1, 2])
621
+
622
+ if fps is not None:
623
+ timestamps = np.arange(start_time, start_time + duration + 1 / fps, 1 / fps)[:len(frames)]
624
+ else:
625
+ timestamps = np.linspace(start_time, start_time + duration, len(frames))
626
+
627
+ max_frames = max_frames if max_frames is not None else MAX_FRAMES
628
+ if max_frames is not None and len(frames) > max_frames:
629
+ indices = np.linspace(0, len(frames) - 1, max_frames, dtype=int)
630
+ frames = frames[indices]
631
+ timestamps = [timestamps[i] for i in indices]
632
+
633
+ if temporal_factor > 1:
634
+ pad_length = temporal_factor - len(frames) % temporal_factor
635
+ frames = np.concatenate([frames, frames[-1:].repeat(pad_length, axis=0)])
636
+ # extend timestamps to cover the padded frames
+ timestamps = list(timestamps)
+ for _ in range(pad_length):
+ timestamps.append(timestamps[-1] + 1 / fps)
637
+
638
+ frames = [frame for frame in frames]
639
+
640
+ return frames, timestamps
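A minimal usage sketch of `load_video` as defined above; the video path and sampling settings are placeholders, and ffmpeg plus the `ffmpeg-python` package must be available for the call to work.

    # Usage sketch only: "demo.mp4" is a placeholder path.
    frames, timestamps = load_video(
        "demo.mp4",        # placeholder video file
        start_time=0.0,
        end_time=8.0,
        fps=1,             # sample roughly one frame per second
        max_frames=8,
        size=384,          # resize so the shortest side is 384
    )
    print(len(frames), [round(t, 2) for t in timestamps])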
641
+
642
+
643
+ def process_video(video_path, processor, s=None, e=None, aspect_ratio='pad', num_frames=None):
644
+ fps = 1 if num_frames is None else None
645
+ # FFmpeg
646
+ frames, timestamps = load_video(video_path, s, e, fps=fps, max_frames=num_frames)
647
+ # Decord
648
+ # frames, timestamps = load_video_from_ids(video_path, s, e, fps=fps, max_frames=num_frames)
649
+
650
+ assert len(frames) == len(timestamps), "Number of frames and timestamps must match."
651
+
652
+ if aspect_ratio == 'pad':
653
+ frames = [expand2square(f, tuple(int(x*255) for x in processor.image_mean)) for f in frames]
654
+
655
+ if aspect_ratio == 'avt':
656
+ frames = [processor.preprocess(frame, return_tensors='pt', image_num=len(frames)) for frame in frames]
657
+ grid_frames = [frames]
658
+ else:
659
+ frames = processor.preprocess(frames, return_tensors='pt', image_num=len(frames))
660
+ grid_frames = [[frames]]
661
+
662
+ return grid_frames, timestamps
663
+
664
+
665
+ def tokenizer_multimodal_token(prompt, tokenizer, multimodal_token=DEFAULT_IMAGE_TOKEN, return_tensors=None):
666
+ """Tokenize text and multimodal tag to input_ids.
667
+
668
+ Args:
669
+ prompt (str): Text prompt (w/ multimodal tag), e.g., '<video>\nDescribe the video.'
670
+ tokenizer (transformers.PreTrainedTokenizer): Tokenizer object.
671
+ multimodal_token (str): Multimodal tag string (e.g., '<video>') whose occurrences are replaced by the corresponding token index.
+ return_tensors (str, optional): If 'pt', return a torch.LongTensor instead of a Python list. Defaults to None.
672
+ """
673
+ multimodal_token_index = MODAL_INDEX_MAP.get(multimodal_token, None)
674
+ if multimodal_token_index is None:
675
+ input_ids = tokenizer(prompt, add_special_tokens=False).input_ids
676
+ else:
677
+ prompt_chunks = [tokenizer(chunk, add_special_tokens=False).input_ids for idx, chunk in enumerate(prompt.split(multimodal_token))]
678
+
679
+ input_ids = []
680
+ for i in range(1, 2 * len(prompt_chunks)):
681
+ if i % 2 == 1:
682
+ input_ids.extend(prompt_chunks[i // 2])
683
+ else:
684
+ input_ids.append(multimodal_token_index)
685
+
686
+ if return_tensors is not None:
687
+ if return_tensors == 'pt':
688
+ return torch.tensor(input_ids, dtype=torch.long)
689
+ raise ValueError(f'Unsupported tensor type: {return_tensors}')
690
+ return input_ids
691
+
692
+
693
+ def get_model_name_from_path(model_path):
694
+ model_path = model_path.strip("/")
695
+ model_paths = model_path.split("/")
696
+ if model_paths[-1].startswith('checkpoint-'):
697
+ return model_paths[-2] + "_" + model_paths[-1]
698
+ else:
699
+ return model_paths[-1]
700
+
701
+
702
+ class KeywordsStoppingCriteria(StoppingCriteria):
703
+ def __init__(self, keywords, tokenizer, input_ids):
704
+ self.keywords = keywords
705
+ self.keyword_ids = []
706
+ self.max_keyword_len = 0
707
+ for keyword in keywords:
708
+ cur_keyword_ids = tokenizer(keyword).input_ids
709
+ if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id:
710
+ cur_keyword_ids = cur_keyword_ids[1:]
711
+ if len(cur_keyword_ids) > self.max_keyword_len:
712
+ self.max_keyword_len = len(cur_keyword_ids)
713
+ self.keyword_ids.append(torch.tensor(cur_keyword_ids))
714
+ self.tokenizer = tokenizer
715
+ self.start_len = input_ids.shape[1]
716
+
717
+ def call_for_batch(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
718
+ offset = min(output_ids.shape[1] - self.start_len, self.max_keyword_len)
719
+ self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids]
720
+ for keyword_id in self.keyword_ids:
721
+ if (output_ids[0, -keyword_id.shape[0]:] == keyword_id).all():
722
+ return True
723
+ outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0]
724
+ for keyword in self.keywords:
725
+ if keyword in outputs:
726
+ return True
727
+ return False
728
+
729
+ def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
730
+ outputs = []
731
+ for i in range(output_ids.shape[0]):
732
+ outputs.append(self.call_for_batch(output_ids[i].unsqueeze(0), scores))
733
+ return all(outputs)
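The chunk/token interleaving in `tokenizer_multimodal_token` is easier to follow with a toy restatement that replaces the real tokenizer with `str.split` and uses a made-up modal index; the value -201 below is illustrative, not the actual entry in `MODAL_INDEX_MAP`.

    # Toy restatement of the interleaving loop in tokenizer_multimodal_token.
    # str.split stands in for the tokenizer and -201 is a made-up modal index.
    def interleave(prompt, multimodal_token="<video>", modal_index=-201):
        chunks = [chunk.split() for chunk in prompt.split(multimodal_token)]
        ids = []
        for i in range(1, 2 * len(chunks)):
            if i % 2 == 1:
                ids.extend(chunks[i // 2])   # odd steps: text chunk
            else:
                ids.append(modal_index)      # even steps: modal placeholder
        return ids

    print(interleave("<video>\nDescribe the video."))
    # [-201, 'Describe', 'the', 'video.']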
rynnec/model/__init__.py ADDED
@@ -0,0 +1,196 @@
1
+ # Adopted from https://github.com/haotian-liu/LLaVA. Below is the original copyright:
2
+ # Copyright 2023 Haotian Liu
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+
17
+ import os
18
+ import warnings
19
+ import shutil
20
+
21
+ import torch
22
+ from transformers import PretrainedConfig, AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig, AutoProcessor
23
+
24
+ from .projector import load_mm_projector
25
+ from .videollama3_encoder import Videollama3VisionEncoderModel, Videollama3VisionEncoderConfig
26
+ from .rynnec_qwen2 import RynnecQwen2ForCausalLM, RynnecQwen2Config, Videollama3Qwen2Processor
27
+
28
+ def apply_liger_kernel_to_rynnec():
29
+ from liger_kernel.transformers import (
30
+ apply_liger_kernel_to_mistral,
31
+ apply_liger_kernel_to_qwen2,
32
+ apply_liger_kernel_to_qwen3,
33
+ apply_liger_kernel_to_qwen3_moe,
34
+ )
35
+ from liger_kernel.transformers.rope import liger_rotary_pos_emb
36
+ from liger_kernel.transformers.layer_norm import LigerLayerNorm
37
+ from .videollama3_encoder import modeling_videollama3_encoder
38
+
39
+ apply_liger_kernel_to_mistral()
40
+ apply_liger_kernel_to_qwen2()
41
+
42
+ modeling_videollama3_encoder.apply_rotary_pos_emb_vision = liger_rotary_pos_emb
43
+ modeling_videollama3_encoder.LayerNorm = LigerLayerNorm
44
+
45
+
46
+ def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=False, device_map="auto", **kwargs):
47
+ if 'token' in kwargs:
48
+ token = kwargs['token']
49
+ else:
50
+ token = None
51
+
52
+ # NOTE: auto device_map by default
53
+ # if want to put model into a single device, you can set device_map={"": "cuda:0"}
54
+ kwargs = {"device_map": device_map, **kwargs}
55
+
56
+ config = AutoConfig.from_pretrained(model_path)
57
+ config._attn_implementation = kwargs.pop('attn_implementation', "flash_attention_2") # default to flash_attention_2
58
+
59
+ torch_dtype = config.torch_dtype if hasattr(config, "torch_dtype") else kwargs.pop('torch_dtype', torch.float16)
60
+
61
+ if load_8bit:
62
+ kwargs['load_in_8bit'] = True
63
+ elif load_4bit:
64
+ # NOTE: newer versions of Transformers will report: """ValueError: You can't pass `load_in_4bit`or `load_in_8bit` as a kwarg when passing `quantization_config` argument at the same time."""
65
+ # kwargs['load_in_4bit'] = True
66
+ kwargs['quantization_config'] = BitsAndBytesConfig(
67
+ load_in_4bit=True,
68
+ bnb_4bit_compute_dtype=torch_dtype,
69
+ bnb_4bit_use_double_quant=True,
70
+ bnb_4bit_quant_type='nf4'
71
+ )
72
+ else:
73
+ kwargs['torch_dtype'] = torch_dtype
74
+
75
+ # judge model type
76
+ model_type = config.model_type if hasattr(config, "model_type") else kwargs.pop('model_type', "rynnec_qwen2")
77
+
78
+ # judge pretrain/finetune
79
+ is_alignment = getattr(config, "tune_mm_mlp_adapter", False) or getattr(config, "is_alignment", False)
80
+
81
+ # NOTE: lora/qlora model loading
82
+ if 'lora' in model_name.lower() or 'qlora' in model_name.lower():
83
+ # if True:
84
+ cfg_pretrained = PretrainedConfig.from_pretrained(model_path, token=token)
85
+ # NOTE: AutoConfig will modify `_name_or_path` property to `model_path` if `model_path` is not None.
86
+ # cfg_pretrained = AutoConfig.from_pretrained(model_path, token=token)
87
+ model_base = model_base if model_base is not None else cfg_pretrained._name_or_path
88
+
89
+ # NOTE: remove qlora training quantization config
90
+ if hasattr(cfg_pretrained, 'quantization_config'):
91
+ del cfg_pretrained.quantization_config
92
+ tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False, token=token)
93
+ print('Loading RynnEC from base model...')
94
+
95
+ config_raw = AutoConfig.from_pretrained(model_base)
96
+ new_vocab_size = config.vocab_size
97
+ if config.vocab_size != config_raw.vocab_size:
98
+ config.vocab_size = config_raw.vocab_size
99
+ config.training = False
100
+
101
+ if 'qwen2' in model_base.lower():
102
+ model = RynnecQwen2ForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=config, **kwargs)
103
+ else:
104
+ model = RynnecQwen2ForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=config, **kwargs)
105
+
106
+ model.config.mask_decoder_model = "./checkpoints/sam2_hiera_large.pt"
107
+
108
+ token_num, token_dim = new_vocab_size, model.lm_head.in_features
109
+
110
+ if model.lm_head.weight.shape[0] != token_num:
111
+ model.lm_head.weight = torch.nn.Parameter(torch.empty(token_num, token_dim, device=model.device, dtype=model.dtype))
112
+ model.model.embed_tokens.weight = torch.nn.Parameter(torch.empty(token_num, token_dim, device=model.device, dtype=model.dtype))
113
+
114
+ print('Loading additional RynnEC weights...')
115
+ if os.path.exists(os.path.join(model_path, 'non_lora_trainables.bin')):
116
+ non_lora_trainables = torch.load(os.path.join(model_path, 'non_lora_trainables.bin'), map_location='cpu')
117
+ else:
118
+ # this is probably from HF Hub
119
+ from huggingface_hub import hf_hub_download
120
+ def load_from_hf(repo_id, filename, subfolder=None):
121
+ cache_file = hf_hub_download(
122
+ repo_id=repo_id,
123
+ filename=filename,
124
+ subfolder=subfolder)
125
+ return torch.load(cache_file, map_location='cpu')
126
+ non_lora_trainables = load_from_hf(model_path, 'non_lora_trainables.bin')
127
+
128
+ # add
129
+ sam2_model = torch.load(model.config.mask_decoder_model, map_location='cpu')['model']
130
+ prefix = "base_model.model.grounding_encoder.sam2_model."
131
+ for param_name in sam2_model.keys():
132
+ new_param_name = prefix + param_name
133
+ if new_param_name not in non_lora_trainables.keys():
134
+ non_lora_trainables[new_param_name] = sam2_model[param_name]
135
+
136
+ non_lora_trainables = {(k[11:] if k.startswith('base_model.') else k): v for k, v in non_lora_trainables.items()}
137
+ if any(k.startswith('model.model.') for k in non_lora_trainables):
138
+ non_lora_trainables = {(k[6:] if k.startswith('model.') else k): v for k, v in non_lora_trainables.items()}
139
+ model.load_state_dict(non_lora_trainables, strict=False)
140
+
141
+ from peft import PeftModel
142
+ print('Loading LoRA weights...')
143
+ model = PeftModel.from_pretrained(model, model_path)
144
+ print('Merging LoRA weights...')
145
+ model = model.merge_and_unload()
146
+ print('Model is loaded...')
147
+
148
+
149
+ elif model_base is not None or '-base' in model_name.lower() or is_alignment:
150
+ # NOTE: Base/Pretrain model loading
151
+ print('Loading RynnEC from base model...')
152
+ cfg_pretrained = PretrainedConfig.from_pretrained(model_path, token=token)
153
+ # NOTE: AutoConfig will modify `_name_or_path` property to `model_path` if `model_path` is not None.
154
+ # cfg_pretrained = AutoConfig.from_pretrained(model_path, token=token)
155
+ model_base = model_base if model_base is not None else cfg_pretrained._name_or_path
156
+
157
+ tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False, token=token)
158
+
159
+ if model_type in ['rynnec', 'rynnec_qwen2']:
160
+ model = RynnecQwen2ForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=config, **kwargs)
161
+ else:
162
+ model = RynnecQwen2ForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=config, **kwargs)
163
+
164
+ # NOTE; loading vision-language projector
165
+ # * old codes for loading local mm_projector.bin
166
+ # mm_projector_weights = torch.load(os.path.join(model_path, 'mm_projector.bin'), map_location='cpu')
167
+ # mm_projector_weights = {k: v.to(torch.float16) for k, v in mm_projector_weights.items()}
168
+ # model.load_state_dict(mm_projector_weights, strict=False)
169
+ # * new codes which supports loading mm_projector.bin both offline and online
170
+ mm_projector_weights = load_mm_projector(model_path, token=token)
171
+ model.load_state_dict(mm_projector_weights, strict=False)
172
+ elif 'rynnec' in model_type:
173
+ # NOTE: SFT model loading
174
+ tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False, token=token)
175
+
176
+ if model_type in ['rynnec_qwen2']:
177
+ model = RynnecQwen2ForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, config=config, **kwargs)
178
+ else:
179
+ model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, config=config, **kwargs)
180
+ else:
181
+ tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True, token=token)
182
+ model = AutoModelForCausalLM.from_pretrained(model_path, config=config, **kwargs)
183
+
184
+ processor = None
185
+
186
+ # if "videollama" in model_type:
187
+ if True:
188
+ vision_encoder = model.get_vision_encoder()
189
+ processor = vision_encoder.image_processor
190
+
191
+ if hasattr(model.config, "max_sequence_length"):
192
+ context_len = model.config.max_sequence_length
193
+ else:
194
+ context_len = 2048
195
+
196
+ return tokenizer, model, processor, context_len
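A hedged usage sketch of `load_pretrained_model`; the checkpoint path is a placeholder and the import path assumes the package layout introduced by this diff.

    # Usage sketch only: "./checkpoints/RynnEC" is a placeholder checkpoint path.
    from rynnec.model import load_pretrained_model

    tokenizer, model, processor, context_len = load_pretrained_model(
        model_path="./checkpoints/RynnEC",   # placeholder
        model_base=None,
        model_name="rynnec_qwen2",
        load_4bit=False,
        device_map="auto",
    )
    print(type(model).__name__, context_len)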
rynnec/model/encoder.py ADDED
@@ -0,0 +1,282 @@
1
+ import os
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from transformers import (CLIPImageProcessor, CLIPVisionConfig,
6
+ CLIPVisionModel, SiglipImageProcessor,
7
+ SiglipVisionConfig, SiglipVisionModel)
8
+
9
+ from .videollama3_encoder import (Videollama3VisionEncoderConfig,
10
+ Videollama3VisionEncoderModel, Videollama3ImageProcessor)
11
+
12
+
13
+ class CLIPVisionEncoder(nn.Module):
14
+
15
+ def __init__(self, vision_encoder, args, delay_load=False):
16
+ super().__init__()
17
+
18
+ self.is_loaded = False
19
+
20
+ self.vision_encoder_name = vision_encoder
21
+ self.select_layer = args.mm_vision_select_layer
22
+ self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch')
23
+
24
+ if not delay_load:
25
+ self.attn_implementation = getattr(args, 'mm_attn_implementation', 'flash_attention_2')
26
+ self.load_model()
27
+ else:
28
+ # uncertain whether flash-attention-2 is supported during inference phase.
29
+ self.attn_implementation = 'sdpa' # 'eager'
30
+ self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_encoder_name)
31
+
32
+ def load_model(self):
33
+ if self.is_loaded:
34
+ print('Vision encoder is already loaded; skipping repeated `load_model` call.')
35
+ return
36
+
37
+ self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_encoder_name)
38
+
39
+ self.vision_encoder = CLIPVisionModel.from_pretrained(self.vision_encoder_name,
40
+ attn_implementation=self.attn_implementation)
41
+
42
+ self.is_loaded = True
43
+
44
+ def feature_select(self, image_forward_outs):
45
+ image_features = image_forward_outs.hidden_states[self.select_layer]
46
+ if self.select_feature == 'patch':
47
+ image_features = image_features[:, 1:]
48
+ elif self.select_feature == 'cls_patch':
49
+ image_features = image_features
50
+ else:
51
+ raise ValueError(f'Unexpected select feature: {self.select_feature}')
52
+ return image_features
53
+
54
+ def forward(self, images, **kwargs):
55
+ images = torch.cat(images)
56
+ if type(images) is list:
57
+ image_features = []
58
+ for image in images:
59
+ image_forward_out = self.vision_encoder(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True)
60
+ image_feature = self.feature_select(image_forward_out).to(image.dtype)
61
+ image_features.append(image_feature)
62
+ else:
63
+ image_forward_outs = self.vision_encoder(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True)
64
+ image_features = self.feature_select(image_forward_outs).to(images.dtype)
65
+
66
+ return image_features
67
+
68
+ @property
69
+ def dummy_feature(self):
70
+ return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
71
+
72
+ @property
73
+ def dtype(self):
74
+ return self.vision_encoder.dtype
75
+
76
+ @property
77
+ def device(self):
78
+ return self.vision_encoder.device
79
+
80
+ @property
81
+ def config(self):
82
+ if self.is_loaded:
83
+ return self.vision_encoder.config
84
+ else:
85
+ return self.cfg_only
86
+
87
+ @property
88
+ def hidden_size(self):
89
+ return self.config.hidden_size
90
+
91
+ @property
92
+ def num_patches(self):
93
+ return (self.config.image_size // self.config.patch_size) ** 2
94
+
95
+ @property
96
+ def num_patches_per_side(self):
97
+ return self.config.image_size // self.config.patch_size
98
+
99
+ @property
100
+ def image_size(self):
101
+ return self.config.image_size
102
+
103
+
104
+ class SiglipVisionEncoder(nn.Module):
105
+
106
+ def __init__(self, vision_encoder, args, delay_load=False):
107
+ super().__init__()
108
+
109
+ self.is_loaded = False
110
+
111
+ self.vision_encoder_name = vision_encoder
112
+ self.select_layer = args.mm_vision_select_layer
113
+ self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch')
114
+
115
+ if not delay_load:
116
+ self.attn_implementation = getattr(args, 'mm_attn_implementation', 'flash_attention_2')
117
+ self.load_model()
118
+ else:
119
+ # uncertain whether flash-attention-2 is supported during inference phase.
120
+ self.attn_implementation = 'sdpa' # 'eager'
121
+ self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_encoder_name)
122
+
123
+ def load_model(self):
124
+ if self.is_loaded:
125
+ print('Vision encoder is already loaded; skipping repeated `load_model` call.')
126
+ return
127
+
128
+ self.image_processor = SiglipImageProcessor.from_pretrained(self.vision_encoder_name)
129
+
130
+ self.vision_encoder = SiglipVisionModel.from_pretrained(self.vision_encoder_name,
131
+ attn_implementation=self.attn_implementation)
132
+
133
+ self.is_loaded = True
134
+
135
+ def feature_select(self, image_forward_outs):
136
+ image_features = image_forward_outs.hidden_states[self.select_layer]
137
+ if self.select_feature == 'patch':
138
+ image_features = image_features
139
+ else:
140
+ raise ValueError(f'Unexpected select feature: {self.select_feature}')
141
+ return image_features
142
+
143
+ def forward(self, images, **kwargs):
144
+ images = torch.cat(images)
145
+ if type(images) is list:
146
+ image_features = []
147
+ for image in images:
148
+ image_forward_out = self.vision_encoder(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True)
149
+ image_feature = self.feature_select(image_forward_out).to(image.dtype)
150
+ image_features.append(image_feature)
151
+ else:
152
+ image_forward_outs = self.vision_encoder(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True)
153
+ image_features = self.feature_select(image_forward_outs).to(images.dtype)
154
+
155
+ return image_features
156
+
157
+ @property
158
+ def dummy_feature(self):
159
+ return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
160
+
161
+ @property
162
+ def dtype(self):
163
+ return self.vision_encoder.dtype
164
+
165
+ @property
166
+ def device(self):
167
+ return self.vision_encoder.device
168
+
169
+ @property
170
+ def config(self):
171
+ if self.is_loaded:
172
+ return self.vision_encoder.config
173
+ else:
174
+ return self.cfg_only
175
+
176
+ @property
177
+ def hidden_size(self):
178
+ return self.config.hidden_size
179
+
180
+ @property
181
+ def num_patches(self):
182
+ return (self.config.image_size // self.config.patch_size) ** 2
183
+
184
+ @property
185
+ def num_patches_per_side(self):
186
+ return self.config.image_size // self.config.patch_size
187
+
188
+ @property
189
+ def image_size(self):
190
+ return self.config.image_size
191
+
192
+
193
+ class Videollama3VisionEncoder(nn.Module):
194
+
195
+ def __init__(self, vision_encoder, args, delay_load=False):
196
+ super().__init__()
197
+
198
+ self.is_loaded = False
199
+
200
+ self.vision_encoder_name = vision_encoder
201
+ self.args = args
202
+
203
+ if not delay_load:
204
+ self.attn_implementation = getattr(args, 'mm_attn_implementation', 'flash_attention_2')
205
+ self.load_model(self.args)
206
+ else:
207
+ # uncertain whether flash-attention-2 is supported during inference phase.
208
+ self.attn_implementation = 'sdpa' # 'eager'
209
+ self.cfg_only = Videollama3VisionEncoderConfig.from_pretrained(self.vision_encoder_name)
210
+
211
+ def load_model(self, args):
212
+ if self.is_loaded:
213
+ print('Vision encoder is already loaded; skipping repeated `load_model` call.')
214
+ return
215
+
216
+ # merge_size is set to 1 by default, because STAGE1, STAGE1.5, STAGE2 are trained with merge_size=1
217
+ # for STAGE3, merge_size is set to 2 via arguments.
218
+ self.image_processor = Videollama3ImageProcessor.from_pretrained(self.vision_encoder_name)
219
+
220
+ # merge_size is fixed to 1 for STAGE1, STAGE1.5, STAGE2, STAGE3 in encoder and can be modified in connector.
221
+ self.cfg_only = Videollama3VisionEncoderConfig.from_pretrained(self.vision_encoder_name)
222
+
223
+ self.vision_encoder = Videollama3VisionEncoderModel.from_pretrained(
224
+ self.vision_encoder_name,
225
+ torch_dtype=args.torch_dtype,
226
+ attn_implementation=self.attn_implementation)
227
+
228
+ self.is_loaded = True
229
+
230
+ def forward(self, pixel_values, grid_sizes, merge_sizes, **kwargs):
231
+ image_features = self.vision_encoder(pixel_values, grid_sizes, merge_sizes)
232
+ return image_features
233
+
234
+ @property
235
+ def dummy_feature(self):
236
+ return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
237
+
238
+ @property
239
+ def dtype(self):
240
+ return self.vision_encoder.dtype
241
+
242
+ @property
243
+ def device(self):
244
+ return self.vision_encoder.device
245
+
246
+ @property
247
+ def config(self):
248
+ if self.is_loaded:
249
+ return self.vision_encoder.config
250
+ else:
251
+ return self.cfg_only
252
+
253
+ @property
254
+ def hidden_size(self):
255
+ return self.config.hidden_size
256
+
257
+ @property
258
+ def num_patches(self):
259
+ return -1
260
+
261
+ @property
262
+ def num_patches_per_side(self):
263
+ return -1
264
+
265
+ @property
266
+ def image_size(self):
267
+ return -1
268
+
269
+
270
+ def build_vision_encoder(vision_encoder_cfg, **kwargs):
271
+ vision_encoder = getattr(vision_encoder_cfg, 'mm_vision_encoder', getattr(vision_encoder_cfg, 'vision_encoder', None))
272
+
273
+ if 'clip' in vision_encoder:
274
+ vision_encoder = CLIPVisionEncoder(vision_encoder, args=vision_encoder_cfg, **kwargs)
275
+ elif 'navit' in vision_encoder.lower() or 'damovl' in vision_encoder:
276
+ vision_encoder = Videollama3VisionEncoder(vision_encoder, args=vision_encoder_cfg, **kwargs)
277
+ elif 'siglip' in vision_encoder:
278
+ vision_encoder = SiglipVisionEncoder(vision_encoder, args=vision_encoder_cfg, **kwargs)
279
+ else:
280
+ raise ValueError(f'Unknown vision encoder: {vision_encoder}')
281
+
282
+ return vision_encoder
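The string-matching dispatch at the end of `build_vision_encoder` can be summarized in isolation; the encoder path below is only an example, not necessarily the one shipped with the released checkpoints.

    # Stand-alone restatement of the name-based dispatch in build_vision_encoder.
    def pick_encoder_class(name: str) -> str:
        if 'clip' in name:
            return 'CLIPVisionEncoder'
        if 'navit' in name.lower() or 'damovl' in name:
            return 'Videollama3VisionEncoder'
        if 'siglip' in name:
            return 'SiglipVisionEncoder'
        raise ValueError(f'Unknown vision encoder: {name}')

    print(pick_encoder_class('google/siglip-so400m-patch14-384'))  # SiglipVisionEncoder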
rynnec/model/extension/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .sam2_base import SAM2Base
rynnec/model/extension/sam2_base.py ADDED
@@ -0,0 +1,298 @@
1
+ # Adopted from https://github.com/magic-research/Sa2VA/blob/main/projects/llava_sam2/models/extension/sam2_base.py.
2
+ # Below is the original copyright:
3
+ # coding=utf-8
4
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+ import torch
18
+ import torch.nn.functional as F
19
+
20
+ from third_parts.sam2.modeling.sam2_base import SAM2Base as _SAM2Base
21
+ from third_parts.sam2.modeling.sam2_base import NO_OBJ_SCORE
22
+
23
+
24
+ class SAM2Base(_SAM2Base):
25
+
26
+ def track_step(
27
+ self,
28
+ frame_idx,
29
+ is_init_cond_frame,
30
+ current_vision_feats,
31
+ current_vision_pos_embeds,
32
+ feat_sizes,
33
+ point_inputs,
34
+ mask_inputs,
35
+ output_dict,
36
+ num_frames,
37
+ track_in_reverse=False, # tracking in reverse time order (for demo usage)
38
+ # Whether to run the memory encoder on the predicted masks. Sometimes we might want
39
+ # to skip the memory encoder with `run_mem_encoder=False`. For example,
40
+ # in demo we might call `track_step` multiple times for each user click,
41
+ # and only encode the memory when the user finalizes their clicks. And in ablation
42
+ # settings like SAM training on static images, we don't need the memory encoder.
43
+ run_mem_encoder=True,
44
+ # The previously predicted SAM mask logits (which can be fed together with new clicks in demo).
45
+ prev_sam_mask_logits=None,
46
+ ## Extension: LLM prompt
47
+ language_embd=None,
48
+ ):
49
+ current_out = {"point_inputs": point_inputs, "mask_inputs": mask_inputs}
50
+ # High-resolution feature maps for the SAM head, reshape (HW)BC => BCHW
51
+ if len(current_vision_feats) > 1:
52
+ high_res_features = [
53
+ x.permute(1, 2, 0).view(x.size(1), x.size(2), *s)
54
+ for x, s in zip(current_vision_feats[:-1], feat_sizes[:-1])
55
+ ]
56
+ else:
57
+ high_res_features = None
58
+ if mask_inputs is not None and self.use_mask_input_as_output_without_sam:
59
+ # When use_mask_input_as_output_without_sam=True, we directly output the mask input
60
+ # (see it as a GT mask) without using a SAM prompt encoder + mask decoder.
61
+ pix_feat = current_vision_feats[-1].permute(1, 2, 0)
62
+ pix_feat = pix_feat.view(-1, self.hidden_dim, *feat_sizes[-1])
63
+ sam_outputs = self._use_mask_as_output(
64
+ pix_feat, high_res_features, mask_inputs
65
+ )
66
+ else:
67
+ # fused the visual feature with previous memory features in the memory bank
68
+ pix_feat_with_mem = self._prepare_memory_conditioned_features(
69
+ frame_idx=frame_idx,
70
+ is_init_cond_frame=is_init_cond_frame,
71
+ current_vision_feats=current_vision_feats[-1:],
72
+ current_vision_pos_embeds=current_vision_pos_embeds[-1:],
73
+ feat_sizes=feat_sizes[-1:],
74
+ output_dict=output_dict,
75
+ num_frames=num_frames,
76
+ track_in_reverse=track_in_reverse,
77
+ )
78
+ # apply SAM-style segmentation head
79
+ # here we might feed previously predicted low-res SAM mask logits into the SAM mask decoder,
80
+ # e.g. in demo where such logits come from earlier interaction instead of correction sampling
81
+ # (in this case, any `mask_inputs` shouldn't reach here as they are sent to _use_mask_as_output instead)
82
+ if prev_sam_mask_logits is not None:
83
+ assert point_inputs is not None and mask_inputs is None
84
+ mask_inputs = prev_sam_mask_logits
85
+ multimask_output = self._use_multimask(is_init_cond_frame, point_inputs)
86
+ sam_outputs = self._forward_sam_heads(
87
+ backbone_features=pix_feat_with_mem,
88
+ point_inputs=point_inputs,
89
+ mask_inputs=mask_inputs,
90
+ high_res_features=high_res_features,
91
+ multimask_output=multimask_output,
92
+ # Inject language Embed if possible
93
+ language_embd=language_embd,
94
+ )
95
+ (
96
+ _,
97
+ _,
98
+ _,
99
+ low_res_masks,
100
+ high_res_masks,
101
+ obj_ptr,
102
+ _,
103
+ ) = sam_outputs
104
+
105
+ current_out["pred_masks"] = low_res_masks
106
+ current_out["pred_masks_high_res"] = high_res_masks
107
+ current_out["obj_ptr"] = obj_ptr
108
+
109
+ # Finally run the memory encoder on the predicted mask to encode
110
+ # it into a new memory feature (that can be used in future frames)
111
+ if run_mem_encoder and self.num_maskmem > 0:
112
+ high_res_masks_for_mem_enc = high_res_masks
113
+ maskmem_features, maskmem_pos_enc = self._encode_new_memory(
114
+ current_vision_feats=current_vision_feats,
115
+ feat_sizes=feat_sizes,
116
+ pred_masks_high_res=high_res_masks_for_mem_enc,
117
+ is_mask_from_pts=(point_inputs is not None),
118
+ )
119
+ current_out["maskmem_features"] = maskmem_features
120
+ current_out["maskmem_pos_enc"] = maskmem_pos_enc
121
+ else:
122
+ current_out["maskmem_features"] = None
123
+ current_out["maskmem_pos_enc"] = None
124
+
125
+ return current_out
126
+
127
+
128
+ def _forward_sam_heads(
129
+ self,
130
+ backbone_features,
131
+ point_inputs=None,
132
+ mask_inputs=None,
133
+ high_res_features=None,
134
+ multimask_output=False,
135
+ ## Extension: LLM prompt
136
+ language_embd=None,
137
+ ):
138
+ """
139
+ Forward SAM prompt encoders and mask heads.
140
+
141
+ Inputs:
142
+ - backbone_features: image features of [B, C, H, W] shape
143
+ - point_inputs: a dictionary with "point_coords" and "point_labels", where
144
+ 1) "point_coords" has [B, P, 2] shape and float32 dtype and contains the
145
+ absolute pixel-unit coordinate in (x, y) format of the P input points
146
+ 2) "point_labels" has shape [B, P] and int32 dtype, where 1 means
147
+ positive clicks, 0 means negative clicks, and -1 means padding
148
+ - mask_inputs: a mask of [B, 1, H*16, W*16] shape, float or bool, with the
149
+ same spatial size as the image.
150
+ - high_res_features: either 1) None or 2) or a list of length 2 containing
151
+ two feature maps of [B, C, 4*H, 4*W] and [B, C, 2*H, 2*W] shapes respectively,
152
+ which will be used as high-resolution feature maps for SAM decoder.
153
+ - multimask_output: if it's True, we output 3 candidate masks and their 3
154
+ corresponding IoU estimates, and if it's False, we output only 1 mask and
155
+ its corresponding IoU estimate.
156
+
157
+ Outputs:
158
+ - low_res_multimasks: [B, M, H*4, W*4] shape (where M = 3 if
159
+ `multimask_output=True` and M = 1 if `multimask_output=False`), the SAM
160
+ output mask logits (before sigmoid) for the low-resolution masks, with 4x
161
+ the resolution (1/4 stride) of the input backbone_features.
162
+ - high_res_multimasks: [B, M, H*16, W*16] shape (where M = 3
163
+ if `multimask_output=True` and M = 1 if `multimask_output=False`),
164
+ upsampled from the low-resolution masks, with shape size as the image
165
+ (stride is 1 pixel).
166
+ - ious, [B, M] shape, where (where M = 3 if `multimask_output=True` and M = 1
167
+ if `multimask_output=False`), the estimated IoU of each output mask.
168
+ - low_res_masks: [B, 1, H*4, W*4] shape, the best mask in `low_res_multimasks`.
169
+ If `multimask_output=True`, it's the mask with the highest IoU estimate.
170
+ If `multimask_output=False`, it's the same as `low_res_multimasks`.
171
+ - high_res_masks: [B, 1, H*16, W*16] shape, the best mask in `high_res_multimasks`.
172
+ If `multimask_output=True`, it's the mask with the highest IoU estimate.
173
+ If `multimask_output=False`, it's the same as `high_res_multimasks`.
174
+ - obj_ptr: [B, C] shape, the object pointer vector for the output mask, extracted
175
+ based on the output token from the SAM mask decoder.
176
+ """
177
+ B = backbone_features.size(0)
178
+ device = backbone_features.device
179
+ assert backbone_features.size(1) == self.sam_prompt_embed_dim
180
+ assert backbone_features.size(2) == self.sam_image_embedding_size
181
+ assert backbone_features.size(3) == self.sam_image_embedding_size
182
+
183
+ # a) Handle point prompts
184
+ if point_inputs is not None:
185
+ sam_point_coords = point_inputs["point_coords"]
186
+ sam_point_labels = point_inputs["point_labels"]
187
+ assert sam_point_coords.size(0) == B and sam_point_labels.size(0) == B
188
+ else:
189
+ # If no points are provided, pad with an empty point (with label -1)
190
+ sam_point_coords = torch.zeros(B, 1, 2, device=device)
191
+ sam_point_labels = -torch.ones(B, 1, dtype=torch.int32, device=device)
192
+
193
+ # b) Handle mask prompts
194
+ if mask_inputs is not None:
195
+ # If mask_inputs is provided, downsize it into low-res mask input if needed
196
+ # and feed it as a dense mask prompt into the SAM mask encoder
197
+ assert len(mask_inputs.shape) == 4 and mask_inputs.shape[:2] == (B, 1)
198
+ if mask_inputs.shape[-2:] != self.sam_prompt_encoder.mask_input_size:
199
+ sam_mask_prompt = F.interpolate(
200
+ mask_inputs.float(),
201
+ size=self.sam_prompt_encoder.mask_input_size,
202
+ align_corners=False,
203
+ mode="bilinear",
204
+ antialias=True, # use antialias for downsampling
205
+ )
206
+ else:
207
+ sam_mask_prompt = mask_inputs
208
+ else:
209
+ # Otherwise, simply feed None (and SAM's prompt encoder will add
210
+ # a learned `no_mask_embed` to indicate no mask input in this case).
211
+ sam_mask_prompt = None
212
+
213
+ sparse_embeddings, dense_embeddings = self.sam_prompt_encoder(
214
+ points=(sam_point_coords, sam_point_labels),
215
+ boxes=None,
216
+ masks=sam_mask_prompt,
217
+ )
218
+
219
+ ## Extension: LLM prompt
220
+ if language_embd is not None:
221
+ # B N C
222
+ # print('sparse_embeddings ', sparse_embeddings.shape, 'language_embd ', language_embd.shape)
223
+ assert sparse_embeddings.size(0) == language_embd.size(0)
224
+ assert sparse_embeddings.size(2) == language_embd.size(2)
225
+ sparse_embeddings = torch.cat([sparse_embeddings, language_embd], dim=1)
226
+
227
+ (
228
+ low_res_multimasks,
229
+ ious,
230
+ sam_output_tokens,
231
+ object_score_logits,
232
+ ) = self.sam_mask_decoder(
233
+ image_embeddings=backbone_features,
234
+ image_pe=self.sam_prompt_encoder.get_dense_pe(),
235
+ sparse_prompt_embeddings=sparse_embeddings,
236
+ dense_prompt_embeddings=dense_embeddings,
237
+ multimask_output=multimask_output,
238
+ repeat_image=False, # the image is already batched
239
+ high_res_features=high_res_features,
240
+ )
241
+ if self.pred_obj_scores:
242
+ is_obj_appearing = object_score_logits > 0
243
+
244
+ # Mask used for spatial memories is always a *hard* choice between obj and no obj,
245
+ # consistent with the actual mask prediction
246
+ # print('Do torch.where !!!')
247
+ # low_res_multimasks = torch.where(
248
+ # is_obj_appearing[:, None, None],
249
+ # low_res_multimasks,
250
+ # NO_OBJ_SCORE,
251
+ # )
252
+
253
+ # convert masks from possibly bfloat16 (or float16) to float32
254
+ # (older PyTorch versions before 2.1 don't support `interpolate` on bf16)
255
+ low_res_multimasks = low_res_multimasks.float()
256
+ high_res_multimasks = F.interpolate(
257
+ low_res_multimasks,
258
+ size=(self.image_size, self.image_size),
259
+ mode="bilinear",
260
+ align_corners=False,
261
+ )
262
+
263
+ sam_output_token = sam_output_tokens[:, 0]
264
+ if multimask_output:
265
+ # take the best mask prediction (with the highest IoU estimation)
266
+ best_iou_inds = torch.argmax(ious, dim=-1)
267
+ batch_inds = torch.arange(B, device=device)
268
+ low_res_masks = low_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1)
269
+ high_res_masks = high_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1)
270
+ if sam_output_tokens.size(1) > 1:
271
+ sam_output_token = sam_output_tokens[batch_inds, best_iou_inds]
272
+ else:
273
+ low_res_masks, high_res_masks = low_res_multimasks, high_res_multimasks
274
+
275
+ # Extract object pointer from the SAM output token (with occlusion handling)
276
+ obj_ptr = self.obj_ptr_proj(sam_output_token)
277
+ if self.pred_obj_scores:
278
+ # Allow *soft* no obj ptr, unlike for masks
279
+ if self.soft_no_obj_ptr:
280
+ # Only hard possible with gt
281
+ assert not self.teacher_force_obj_scores_for_mem
282
+ lambda_is_obj_appearing = object_score_logits.sigmoid()
283
+ else:
284
+ lambda_is_obj_appearing = is_obj_appearing.float()
285
+
286
+ if self.fixed_no_obj_ptr:
287
+ obj_ptr = lambda_is_obj_appearing * obj_ptr
288
+ obj_ptr = obj_ptr + (1 - lambda_is_obj_appearing) * self.no_obj_ptr
289
+
290
+ return (
291
+ low_res_multimasks,
292
+ high_res_multimasks,
293
+ ious,
294
+ low_res_masks,
295
+ high_res_masks,
296
+ obj_ptr,
297
+ object_score_logits,
298
+ )
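The core extension in this file is the `language_embd` argument: when present, it is appended to SAM's sparse prompt tokens before the mask decoder runs. A shape-only sketch of that concatenation, with all sizes made up for illustration:

    # Shape sketch of the language-prompt injection in _forward_sam_heads.
    import torch

    B, C = 2, 256                                # batch size, sam_prompt_embed_dim (example)
    sparse_embeddings = torch.zeros(B, 3, C)     # tokens from the SAM prompt encoder
    language_embd = torch.zeros(B, 1, C)         # one language token per object
    sparse_embeddings = torch.cat([sparse_embeddings, language_embd], dim=1)
    print(sparse_embeddings.shape)               # torch.Size([2, 4, 256])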
rynnec/model/loss.py ADDED
@@ -0,0 +1,597 @@
1
+ # Adopted from https://github.com/magic-research/Sa2VA.
2
+ # Below is the original copyright:
3
+ # coding=utf-8
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ linear_cross_entropy = None
17
+
18
+ import torch
19
+ import torch.nn.functional as F
20
+ import torch.nn as nn
21
+ from rynnec.constants import IGNORE_INDEX
22
+ from torch import Tensor
23
+ import logging
+ import warnings
24
+ from huggingface_hub import hf_hub_download
25
+ import functools
26
+ from typing import Callable, Optional
27
+
28
+
29
+ def reduce_loss(loss: Tensor, reduction: str) -> Tensor:
30
+ """Reduce loss as specified.
31
+
32
+ Args:
33
+ loss (Tensor): Elementwise loss tensor.
34
+ reduction (str): Options are "none", "mean" and "sum".
35
+
36
+ Return:
37
+ Tensor: Reduced loss tensor.
38
+ """
39
+ reduction_enum = F._Reduction.get_enum(reduction)
40
+ # none: 0, elementwise_mean:1, sum: 2
41
+ if reduction_enum == 0:
42
+ return loss
43
+ elif reduction_enum == 1:
44
+ return loss.mean()
45
+ elif reduction_enum == 2:
46
+ return loss.sum()
47
+
48
+
49
+ def weight_reduce_loss(loss: Tensor,
50
+ weight: Optional[Tensor] = None,
51
+ reduction: str = 'mean',
52
+ avg_factor: Optional[float] = None) -> Tensor:
53
+ """Apply element-wise weight and reduce loss.
54
+
55
+ Args:
56
+ loss (Tensor): Element-wise loss.
57
+ weight (Optional[Tensor], optional): Element-wise weights.
58
+ Defaults to None.
59
+ reduction (str, optional): Same as built-in losses of PyTorch.
60
+ Defaults to 'mean'.
61
+ avg_factor (Optional[float], optional): Average factor when
62
+ computing the mean of losses. Defaults to None.
63
+
64
+ Returns:
65
+ Tensor: Processed loss values.
66
+ """
67
+ # if weight is specified, apply element-wise weight
68
+ if weight is not None:
69
+ loss = loss * weight
70
+
71
+ # if avg_factor is not specified, just reduce the loss
72
+ if avg_factor is None:
73
+ loss = reduce_loss(loss, reduction)
74
+ else:
75
+ # if reduction is mean, then average the loss by avg_factor
76
+ if reduction == 'mean':
77
+ # Avoid causing ZeroDivisionError when avg_factor is 0.0,
78
+ # i.e., all labels of an image belong to ignore index.
79
+ eps = torch.finfo(torch.float32).eps
80
+ loss = loss.sum() / (avg_factor + eps)
81
+ # if reduction is 'none', then do nothing, otherwise raise an error
82
+ elif reduction != 'none':
83
+ raise ValueError('avg_factor can not be used with reduction="sum"')
84
+ return loss
85
+
86
+
87
+ def dice_loss(pred,
88
+ target,
89
+ weight=None,
90
+ eps=1e-3,
91
+ reduction='mean',
92
+ naive_dice=False,
93
+ avg_factor=None):
94
+ """Calculate dice loss, there are two forms of dice loss is supported:
95
+
96
+ - the one proposed in `V-Net: Fully Convolutional Neural
97
+ Networks for Volumetric Medical Image Segmentation
98
+ <https://arxiv.org/abs/1606.04797>`_.
99
+ - the dice loss in which the power of the number in the
100
+ denominator is the first power instead of the second
101
+ power.
102
+
103
+ Args:
104
+ pred (torch.Tensor): The prediction, has a shape (n, *)
105
+ target (torch.Tensor): The learning label of the prediction,
106
+ shape (n, *), same shape of pred.
107
+ weight (torch.Tensor, optional): The weight of loss for each
108
+ prediction, has a shape (n,). Defaults to None.
109
+ eps (float): Avoid dividing by zero. Default: 1e-3.
110
+ reduction (str, optional): The method used to reduce the loss into
111
+ a scalar. Defaults to 'mean'.
112
+ Options are "none", "mean" and "sum".
113
+ naive_dice (bool, optional): If false, use the dice
114
+ loss defined in the V-Net paper, otherwise, use the
115
+ naive dice loss in which the power of the number in the
116
+ denominator is the first power instead of the second
117
+ power.Defaults to False.
118
+ avg_factor (int, optional): Average factor that is used to average
119
+ the loss. Defaults to None.
120
+ """
121
+
122
+ input = pred.flatten(1)
123
+ target = target.flatten(1).float()
124
+
125
+ a = torch.sum(input * target, 1)
126
+ if naive_dice:
127
+ b = torch.sum(input, 1)
128
+ c = torch.sum(target, 1)
129
+ d = (2 * a + eps) / (b + c + eps)
130
+ else:
131
+ b = torch.sum(input * input, 1) + eps
132
+ c = torch.sum(target * target, 1) + eps
133
+ d = (2 * a) / (b + c)
134
+
135
+ loss = 1 - d
136
+ if weight is not None:
137
+ assert weight.ndim == loss.ndim
138
+ assert len(weight) == len(pred)
139
+ loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
140
+ return loss
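A small numeric check of the two branches in `dice_loss` (V-Net form vs. naive form); the tensors are toy values and `pred` is assumed to be already sigmoid-activated.

    # Worked example of the two dice variants computed by dice_loss above.
    import torch

    pred   = torch.tensor([[0.9, 0.1, 0.8, 0.2]])   # already sigmoid-activated
    target = torch.tensor([[1.0, 0.0, 1.0, 0.0]])
    eps = 1e-3

    a = (pred * target).sum(1)                                                     # 1.7
    d_vnet  = (2 * a) / ((pred * pred).sum(1) + eps + (target * target).sum(1) + eps)
    d_naive = (2 * a + eps) / (pred.sum(1) + target.sum(1) + eps)
    print(1 - d_vnet, 1 - d_naive)   # ~tensor([0.0291]) tensor([0.1500])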
141
+
142
+
143
+
144
+ class DiceLoss(nn.Module):
145
+
146
+ def __init__(self,
147
+ use_sigmoid=True,
148
+ activate=True,
149
+ reduction='mean',
150
+ naive_dice=False,
151
+ loss_weight=1.0,
152
+ eps=1e-3):
153
+ """Compute dice loss.
154
+
155
+ Args:
156
+ use_sigmoid (bool, optional): Whether the prediction is
157
+ activated with sigmoid rather than softmax. Defaults to True.
158
+ activate (bool): Whether to activate the predictions inside,
159
+ this will disable the inside sigmoid operation.
160
+ Defaults to True.
161
+ reduction (str, optional): The method used
162
+ to reduce the loss. Options are "none",
163
+ "mean" and "sum". Defaults to 'mean'.
164
+ naive_dice (bool, optional): If false, use the dice
165
+ loss defined in the V-Net paper, otherwise, use the
166
+ naive dice loss in which the power of the number in the
167
+ denominator is the first power instead of the second
168
+ power. Defaults to False.
169
+ loss_weight (float, optional): Weight of loss. Defaults to 1.0.
170
+ eps (float): Avoid dividing by zero. Defaults to 1e-3.
171
+ """
172
+
173
+ super(DiceLoss, self).__init__()
174
+ self.use_sigmoid = use_sigmoid
175
+ self.reduction = reduction
176
+ self.naive_dice = naive_dice
177
+ self.loss_weight = loss_weight
178
+ self.eps = eps
179
+ self.activate = activate
180
+
181
+ def forward(self,
182
+ pred,
183
+ target,
184
+ weight=None,
185
+ reduction_override=None,
186
+ avg_factor=None):
187
+ """Forward function.
188
+
189
+ Args:
190
+ pred (torch.Tensor): The prediction, has a shape (n, *).
191
+ target (torch.Tensor): The label of the prediction,
192
+ shape (n, *), same shape of pred.
193
+ weight (torch.Tensor, optional): The weight of loss for each
194
+ prediction, has a shape (n,). Defaults to None.
195
+ avg_factor (int, optional): Average factor that is used to average
196
+ the loss. Defaults to None.
197
+ reduction_override (str, optional): The reduction method used to
198
+ override the original reduction method of the loss.
199
+ Options are "none", "mean" and "sum".
200
+
201
+ Returns:
202
+ torch.Tensor: The calculated loss
203
+ """
204
+
205
+ assert reduction_override in (None, 'none', 'mean', 'sum')
206
+ reduction = (
207
+ reduction_override if reduction_override else self.reduction)
208
+
209
+ if self.activate:
210
+ if self.use_sigmoid:
211
+ pred = pred.sigmoid()
212
+ else:
213
+ raise NotImplementedError
214
+
215
+ loss = self.loss_weight * dice_loss(
216
+ pred,
217
+ target,
218
+ weight,
219
+ eps=self.eps,
220
+ reduction=reduction,
221
+ naive_dice=self.naive_dice,
222
+ avg_factor=avg_factor)
223
+
224
+ return loss
225
+
226
+
227
+ def cross_entropy_loss(
228
+ hidden_states,
229
+ lm_head,
230
+ position_ids,
231
+ labels,
232
+ reduction_scope="sequence",
233
+ **loss_kwargs
234
+ ):
235
+ batch_size = hidden_states.size(0)
236
+
237
+ shift_hidden_states = hidden_states[..., :-1, :]
238
+ shift_labels = labels[..., 1:]
239
+ mask = shift_labels != IGNORE_INDEX
240
+ shift_hidden_states = shift_hidden_states[mask].contiguous()
241
+ shift_labels = shift_labels[mask].contiguous()
242
+
243
+ if mask.sum() == 0:
244
+ print(f"Got labels={labels}; found no valid tokens to compute the loss on!")
245
+ pseudo_logits = lm_head(hidden_states[:, 0:1])
246
+ loss = 0.0 * pseudo_logits.mean()
247
+ return loss
248
+
249
+ if "num_items_in_batch" not in loss_kwargs:
250
+ reduction = "mean"
251
+ denominator = None
252
+
253
+ elif reduction_scope == "batch":
254
+ reduction = "sum"
255
+ denominator = loss_kwargs["num_items_in_batch"]
256
+
257
+ elif reduction_scope == "sequence":
258
+ reduction = "none"
259
+
260
+ if batch_size == 1:
261
+ # NOTE: packed sequence
262
+ start_indices = torch.nonzero(position_ids[0] == 0)[:, 0]
263
+ end_indices = F.pad(start_indices[1:], (0, 1), value=position_ids.size(1))
264
+ batch_indices = torch.cat(
265
+ [
266
+ torch.full((e - s,), fill_value=i, device=position_ids.device, dtype=torch.long)
267
+ for i, (s, e) in enumerate(zip(start_indices, end_indices))
268
+ ],
269
+ ).unsqueeze(0)
270
+ else:
271
+ batch_indices = torch.arange(batch_size, device=position_ids.device)
272
+ batch_indices = batch_indices.unsqueeze(1).expand(-1, hidden_states.size(1))
273
+
274
+ shift_batch_indices = batch_indices[..., :-1]
275
+ shift_batch_indices = shift_batch_indices[mask].contiguous()
276
+ num_tokens = F.one_hot(shift_batch_indices).sum(dim=0)
277
+ denominator = num_tokens[shift_batch_indices] * loss_kwargs["num_items_in_batch"]
278
+
279
+ else:
280
+ raise ValueError(f"Unknown reduction scope: {reduction_scope}")
281
+
282
+ if linear_cross_entropy is None:
283
+ shift_logits = lm_head(shift_hidden_states)
284
+ loss = torch.nn.functional.cross_entropy(
285
+ shift_logits,
286
+ shift_labels,
287
+ reduction=reduction,
288
+ )
289
+ else:
290
+ loss = linear_cross_entropy(
291
+ shift_hidden_states,
292
+ lm_head.weight,
293
+ shift_labels,
294
+ bias=lm_head.bias,
295
+ reduction=reduction,
296
+ accum_e_fp32=True,
297
+ accum_c_fp32=True,
298
+ )
299
+
300
+ if denominator is not None:
301
+ loss = loss / denominator
302
+ if loss.ndim > 0:
303
+ loss = loss.sum()
304
+
305
+ return loss
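The packed-sequence branch above (batch_size == 1 with per-sequence reduction) recovers which sequence each token belongs to from the positions where `position_ids` restarts at 0. A self-contained sketch with made-up positions:

    # Sketch of the packed-sequence bookkeeping in cross_entropy_loss.
    import torch
    import torch.nn.functional as F

    position_ids = torch.tensor([[0, 1, 2, 0, 1, 0, 1, 2, 3]])     # three packed sequences
    start_indices = torch.nonzero(position_ids[0] == 0)[:, 0]      # tensor([0, 3, 5])
    end_indices = F.pad(start_indices[1:], (0, 1), value=position_ids.size(1))
    batch_indices = torch.cat([
        torch.full((e - s,), i, dtype=torch.long)
        for i, (s, e) in enumerate(zip(start_indices, end_indices))
    ]).unsqueeze(0)
    print(batch_indices)   # tensor([[0, 0, 0, 1, 1, 2, 2, 2, 2]])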
306
+
307
+
308
+
309
+ def cross_entropy(pred,
310
+ label,
311
+ weight=None,
312
+ reduction='mean',
313
+ avg_factor=None,
314
+ class_weight=None,
315
+ ignore_index=-100,
316
+ avg_non_ignore=False):
317
+ """Calculate the CrossEntropy loss.
318
+
319
+ Args:
320
+ pred (torch.Tensor): The prediction with shape (N, C), C is the number
321
+ of classes.
322
+ label (torch.Tensor): The learning label of the prediction.
323
+ weight (torch.Tensor, optional): Sample-wise loss weight.
324
+ reduction (str, optional): The method used to reduce the loss.
325
+ avg_factor (int, optional): Average factor that is used to average
326
+ the loss. Defaults to None.
327
+ class_weight (list[float], optional): The weight for each class.
328
+ ignore_index (int | None): The label index to be ignored.
329
+ If None, it will be set to default value. Default: -100.
330
+ avg_non_ignore (bool): The flag decides whether the loss is
331
+ only averaged over non-ignored targets. Default: False.
332
+
333
+ Returns:
334
+ torch.Tensor: The calculated loss
335
+ """
336
+ # The default value of ignore_index is the same as F.cross_entropy
337
+ ignore_index = -100 if ignore_index is None else ignore_index
338
+ # element-wise losses
339
+ loss = F.cross_entropy(
340
+ pred,
341
+ label,
342
+ weight=class_weight,
343
+ reduction='none',
344
+ ignore_index=ignore_index)
345
+
346
+ # average loss over non-ignored elements
347
+ # pytorch's official cross_entropy average loss over non-ignored elements
348
+ # refer to https://github.com/pytorch/pytorch/blob/56b43f4fec1f76953f15a627694d4bba34588969/torch/nn/functional.py#L2660 # noqa
349
+ if (avg_factor is None) and avg_non_ignore and reduction == 'mean':
350
+ avg_factor = label.numel() - (label == ignore_index).sum().item()
351
+
352
+ # apply weights and do the reduction
353
+ if weight is not None:
354
+ weight = weight.float()
355
+ loss = weight_reduce_loss(
356
+ loss, weight=weight, reduction=reduction, avg_factor=avg_factor)
357
+
358
+ return loss
359
+
360
+
361
+ def _expand_onehot_labels(labels, label_weights, label_channels, ignore_index):
362
+ """Expand onehot labels to match the size of prediction."""
363
+ bin_labels = labels.new_full((labels.size(0), label_channels), 0)
364
+ valid_mask = (labels >= 0) & (labels != ignore_index)
365
+ inds = torch.nonzero(
366
+ valid_mask & (labels < label_channels), as_tuple=False)
367
+
368
+ if inds.numel() > 0:
369
+ bin_labels[inds, labels[inds]] = 1
370
+
371
+ valid_mask = valid_mask.view(-1, 1).expand(labels.size(0),
372
+ label_channels).float()
373
+ if label_weights is None:
374
+ bin_label_weights = valid_mask
375
+ else:
376
+ bin_label_weights = label_weights.view(-1, 1).repeat(1, label_channels)
377
+ bin_label_weights *= valid_mask
378
+
379
+ return bin_labels, bin_label_weights, valid_mask
380
+
381
+
382
+ def binary_cross_entropy(pred,
383
+ label,
384
+ weight=None,
385
+ reduction='mean',
386
+ avg_factor=None,
387
+ class_weight=None,
388
+ ignore_index=-100,
389
+ avg_non_ignore=False):
390
+ """Calculate the binary CrossEntropy loss.
391
+
392
+ Args:
393
+ pred (torch.Tensor): The prediction with shape (N, 1) or (N, ).
394
+ When the shape of pred is (N, 1), label will be expanded to
395
+ one-hot format, and when the shape of pred is (N, ), label
396
+ will not be expanded to one-hot format.
397
+ label (torch.Tensor): The learning label of the prediction,
398
+ with shape (N, ).
399
+ weight (torch.Tensor, optional): Sample-wise loss weight.
400
+ reduction (str, optional): The method used to reduce the loss.
401
+ Options are "none", "mean" and "sum".
402
+ avg_factor (int, optional): Average factor that is used to average
403
+ the loss. Defaults to None.
404
+ class_weight (list[float], optional): The weight for each class.
405
+ ignore_index (int | None): The label index to be ignored.
406
+ If None, it will be set to default value. Default: -100.
407
+ avg_non_ignore (bool): The flag decides whether the loss is
408
+ only averaged over non-ignored targets. Default: False.
409
+
410
+ Returns:
411
+ torch.Tensor: The calculated loss.
412
+ """
413
+ # The default value of ignore_index is the same as F.cross_entropy
414
+ ignore_index = -100 if ignore_index is None else ignore_index
415
+
416
+ if pred.dim() != label.dim():
417
+ label, weight, valid_mask = _expand_onehot_labels(
418
+ label, weight, pred.size(-1), ignore_index)
419
+ else:
420
+ # should mask out the ignored elements
421
+ valid_mask = ((label >= 0) & (label != ignore_index)).float()
422
+ if weight is not None:
423
+ # The inplace writing method will have a mismatched broadcast
424
+ # shape error if the weight and valid_mask dimensions
425
+ # are inconsistent such as (B,N,1) and (B,N,C).
426
+ weight = weight * valid_mask
427
+ else:
428
+ weight = valid_mask
429
+
430
+ # average loss over non-ignored elements
431
+ if (avg_factor is None) and avg_non_ignore and reduction == 'mean':
432
+ avg_factor = valid_mask.sum().item()
433
+
434
+ # weighted element-wise losses
435
+ weight = weight.float()
436
+ loss = F.binary_cross_entropy_with_logits(
437
+ pred, label.float(), pos_weight=class_weight, reduction='none')
438
+ # do the reduction for the weighted loss
439
+ loss = weight_reduce_loss(
440
+ loss, weight, reduction=reduction, avg_factor=avg_factor)
441
+
442
+ return loss
443
+
444
+
445
+ def mask_cross_entropy(pred,
446
+ target,
447
+ label,
448
+ reduction='mean',
449
+ avg_factor=None,
450
+ class_weight=None,
451
+ ignore_index=None,
452
+ **kwargs):
453
+ """Calculate the CrossEntropy loss for masks.
454
+
455
+ Args:
456
+ pred (torch.Tensor): The prediction with shape (N, C, *), C is the
457
+ number of classes. The trailing * indicates arbitrary shape.
458
+ target (torch.Tensor): The learning label of the prediction.
459
+ label (torch.Tensor): ``label`` indicates the class label of the mask
460
+ corresponding object. This will be used to select the mask in the
461
+ of the class which the object belongs to when the mask prediction
462
+ if not class-agnostic.
463
+ reduction (str, optional): The method used to reduce the loss.
464
+ Options are "none", "mean" and "sum".
465
+ avg_factor (int, optional): Average factor that is used to average
466
+ the loss. Defaults to None.
467
+ class_weight (list[float], optional): The weight for each class.
468
+ ignore_index (None): Placeholder, to be consistent with other loss.
469
+ Default: None.
470
+
471
+ Returns:
472
+ torch.Tensor: The calculated loss
473
+
474
+ Example:
475
+ >>> N, C = 3, 11
476
+ >>> H, W = 2, 2
477
+ >>> pred = torch.randn(N, C, H, W) * 1000
478
+ >>> target = torch.rand(N, H, W)
479
+ >>> label = torch.randint(0, C, size=(N,))
480
+ >>> reduction = 'mean'
481
+ >>> avg_factor = None
482
+ >>> class_weights = None
483
+ >>> loss = mask_cross_entropy(pred, target, label, reduction,
484
+ >>> avg_factor, class_weights)
485
+ >>> assert loss.shape == (1,)
486
+ """
487
+ assert ignore_index is None, 'BCE loss does not support ignore_index'
488
+ # TODO: handle these two reserved arguments
489
+ assert reduction == 'mean' and avg_factor is None
490
+ num_rois = pred.size()[0]
491
+ inds = torch.arange(0, num_rois, dtype=torch.long, device=pred.device)
492
+ pred_slice = pred[inds, label].squeeze(1)
493
+ return F.binary_cross_entropy_with_logits(
494
+ pred_slice, target, weight=class_weight, reduction='mean')[None]
495
+
496
+
497
+ class CrossEntropyLoss(nn.Module):
498
+
499
+ def __init__(self,
500
+ use_sigmoid=False,
501
+ use_mask=False,
502
+ reduction='mean',
503
+ class_weight=None,
504
+ ignore_index=None,
505
+ loss_weight=1.0,
506
+ avg_non_ignore=False):
507
+ """CrossEntropyLoss.
508
+
509
+ Args:
510
+ use_sigmoid (bool, optional): Whether the prediction uses sigmoid
511
+ instead of softmax. Defaults to False.
512
+ use_mask (bool, optional): Whether to use mask cross entropy loss.
513
+ Defaults to False.
514
+ reduction (str, optional): The method used to reduce the loss. Defaults to 'mean'.
515
+ Options are "none", "mean" and "sum".
516
+ class_weight (list[float], optional): Weight of each class.
517
+ Defaults to None.
518
+ ignore_index (int | None): The label index to be ignored.
519
+ Defaults to None.
520
+ loss_weight (float, optional): Weight of the loss. Defaults to 1.0.
521
+ avg_non_ignore (bool): If True, the loss is only averaged over
522
+ non-ignored targets. Default: False.
523
+ """
524
+ super(CrossEntropyLoss, self).__init__()
525
+ assert (use_sigmoid is False) or (use_mask is False)
526
+ self.use_sigmoid = use_sigmoid
527
+ self.use_mask = use_mask
528
+ self.reduction = reduction
529
+ self.loss_weight = loss_weight
530
+ self.class_weight = class_weight
531
+ self.ignore_index = ignore_index
532
+ self.avg_non_ignore = avg_non_ignore
533
+ if ((ignore_index is not None) and not self.avg_non_ignore
534
+ and self.reduction == 'mean'):
535
+ warnings.warn(
536
+ 'Default ``avg_non_ignore`` is False. If you would like to '
537
+ 'ignore a certain label and average the loss over non-ignored '
538
+ 'labels (matching the behavior of the official PyTorch '
539
+ 'cross_entropy), set ``avg_non_ignore=True``.')
540
+
541
+ if self.use_sigmoid:
542
+ self.cls_criterion = binary_cross_entropy
543
+ elif self.use_mask:
544
+ self.cls_criterion = mask_cross_entropy
545
+ else:
546
+ self.cls_criterion = cross_entropy
547
+
548
+ def extra_repr(self):
549
+ """Extra repr."""
550
+ s = f'avg_non_ignore={self.avg_non_ignore}'
551
+ return s
552
+
553
+ def forward(self,
554
+ cls_score,
555
+ label,
556
+ weight=None,
557
+ avg_factor=None,
558
+ reduction_override=None,
559
+ ignore_index=None,
560
+ **kwargs):
561
+ """Forward function.
562
+
563
+ Args:
564
+ cls_score (torch.Tensor): The prediction.
565
+ label (torch.Tensor): The learning label of the prediction.
566
+ weight (torch.Tensor, optional): Sample-wise loss weight.
567
+ avg_factor (int, optional): Average factor that is used to average
568
+ the loss. Defaults to None.
569
+ reduction_override (str, optional): The method used to reduce the
570
+ loss. Options are "none", "mean" and "sum".
571
+ ignore_index (int | None): The label index to be ignored.
572
+ If not None, it will override the default value. Default: None.
573
+ Returns:
574
+ torch.Tensor: The calculated loss.
575
+ """
576
+ assert reduction_override in (None, 'none', 'mean', 'sum')
577
+ reduction = (
578
+ reduction_override if reduction_override else self.reduction)
579
+ if ignore_index is None:
580
+ ignore_index = self.ignore_index
581
+
582
+ if self.class_weight is not None:
583
+ class_weight = cls_score.new_tensor(
584
+ self.class_weight, device=cls_score.device)
585
+ else:
586
+ class_weight = None
587
+ loss_cls = self.loss_weight * self.cls_criterion(
588
+ cls_score,
589
+ label,
590
+ weight,
591
+ class_weight=class_weight,
592
+ reduction=reduction,
593
+ avg_factor=avg_factor,
594
+ ignore_index=ignore_index,
595
+ avg_non_ignore=self.avg_non_ignore,
596
+ **kwargs)
597
+ return loss_cls
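
A minimal usage sketch of the module above (an aside, not part of the diff; it assumes the `cross_entropy` helper defined earlier in this file follows the standard mmdetection behavior):

import torch

# Softmax cross entropy over (N, C) logits, ignoring label 255 and averaging
# only over non-ignored samples.
criterion = CrossEntropyLoss(use_sigmoid=False, ignore_index=255, avg_non_ignore=True)
logits = torch.randn(8, 21)
labels = torch.randint(0, 21, (8,))
loss = criterion(logits, labels)

# Sigmoid variant: routes through `binary_cross_entropy`, which one-hot expands
# the integer labels before calling F.binary_cross_entropy_with_logits.
bce_criterion = CrossEntropyLoss(use_sigmoid=True, loss_weight=0.5)
bce_loss = bce_criterion(logits, labels)
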
rynnec/model/predictor/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .sam2_predictor import SAM2VideoPredictor
rynnec/model/predictor/sam2_predictor.py ADDED
@@ -0,0 +1,724 @@
1
+ # Adopted from https://github.com/magic-research/Sa2VA/blob/main/projects/llava_sam2/models/predictor/sam2_predictor.py.
2
+ # Below is the original copyright:
3
+ # coding=utf-8
4
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+ from collections import OrderedDict
18
+
19
+ import torch
20
+ from tqdm import tqdm
21
+
22
+ from ..extension import SAM2Base
23
+ from third_parts.sam2.modeling.sam2_base import NO_OBJ_SCORE
24
+ from third_parts.sam2.utils.misc import fill_holes_in_mask_scores
25
+
26
+
27
+ def _obj_id_to_idx(inference_state, obj_id):
28
+ """Map client-side object id to model-side object index."""
29
+ obj_idx = inference_state["obj_id_to_idx"].get(obj_id, None)
30
+ if obj_idx is not None:
31
+ return obj_idx
32
+
33
+ # This is a new object id not sent to the server before. We only allow adding
34
+ # new objects *before* the tracking starts.
35
+ allow_new_object = not inference_state["tracking_has_started"]
36
+ if allow_new_object:
37
+ # get the next object slot
38
+ obj_idx = len(inference_state["obj_id_to_idx"])
39
+ inference_state["obj_id_to_idx"][obj_id] = obj_idx
40
+ inference_state["obj_idx_to_id"][obj_idx] = obj_id
41
+ inference_state["obj_ids"] = list(inference_state["obj_id_to_idx"])
42
+ # set up input and output structures for this object
43
+ inference_state["point_inputs_per_obj"][obj_idx] = {}
44
+ inference_state["mask_inputs_per_obj"][obj_idx] = {}
45
+ inference_state["output_dict_per_obj"][obj_idx] = {
46
+ "cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
47
+ "non_cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
48
+ }
49
+ inference_state["temp_output_dict_per_obj"][obj_idx] = {
50
+ "cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
51
+ "non_cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
52
+ }
53
+ return obj_idx
54
+ else:
55
+ raise RuntimeError(
56
+ f"Cannot add new object id {obj_id} after tracking starts. "
57
+ f"All existing object ids: {inference_state['obj_ids']}. "
58
+ f"Please call 'reset_state' to restart from scratch."
59
+ )
60
+
61
+
62
+ def _get_maskmem_pos_enc(inference_state, current_out):
63
+ """
64
+ `maskmem_pos_enc` is the same across frames and objects, so we cache it as
65
+ a constant in the inference session to reduce session storage size.
66
+ """
67
+ model_constants = inference_state["constants"]
68
+ # "out_maskmem_pos_enc" should be either a list of tensors or None
69
+ out_maskmem_pos_enc = current_out["maskmem_pos_enc"]
70
+ if out_maskmem_pos_enc is not None:
71
+ if "maskmem_pos_enc" not in model_constants:
72
+ assert isinstance(out_maskmem_pos_enc, list)
73
+ # only take the slice for one object, since it's same across objects
74
+ maskmem_pos_enc = [x[0:1].clone() for x in out_maskmem_pos_enc]
75
+ model_constants["maskmem_pos_enc"] = maskmem_pos_enc
76
+ else:
77
+ maskmem_pos_enc = model_constants["maskmem_pos_enc"]
78
+ # expand the cached maskmem_pos_enc to the actual batch size
79
+ batch_size = out_maskmem_pos_enc[0].size(0)
80
+ expanded_maskmem_pos_enc = [
81
+ x.expand(batch_size, -1, -1, -1) for x in maskmem_pos_enc
82
+ ]
83
+ else:
84
+ expanded_maskmem_pos_enc = None
85
+ return expanded_maskmem_pos_enc
86
+
87
+
88
+ def _obj_idx_to_id(inference_state, obj_idx):
89
+ """Map model-side object index to client-side object id."""
90
+ return inference_state["obj_idx_to_id"][obj_idx]
91
+
92
+
93
+ def _get_obj_num(inference_state):
94
+ """Get the total number of unique object ids received so far in this session."""
95
+ return len(inference_state["obj_idx_to_id"])
96
+
97
+
98
+ class SAM2VideoPredictor(SAM2Base):
99
+ """The predictor class to handle user interactions and manage inference states."""
100
+
101
+ def __init__(
102
+ self,
103
+ fill_hole_area=0,
104
+ # whether to apply non-overlapping constraints on the output object masks
105
+ non_overlap_masks=False,
106
+ # whether to clear non-conditioning memory of the surrounding frames (which may contain outdated information) after adding correction clicks;
107
+ # note that this only applies to *single-object tracking* unless `clear_non_cond_mem_for_multi_obj` is also set to True
108
+ clear_non_cond_mem_around_input=False,
109
+ # whether to also clear non-conditioning memory of the surrounding frames (only effective when `clear_non_cond_mem_around_input` is True).
110
+ clear_non_cond_mem_for_multi_obj=False,
111
+ **kwargs,
112
+ ):
113
+ super().__init__(**kwargs)
114
+ self.fill_hole_area = fill_hole_area
115
+ self.non_overlap_masks = non_overlap_masks
116
+ self.clear_non_cond_mem_around_input = clear_non_cond_mem_around_input
117
+ self.clear_non_cond_mem_for_multi_obj = clear_non_cond_mem_for_multi_obj
118
+
119
+ def _get_image_feature(self, inference_state, frame_idx, batch_size):
120
+ """Compute the image features on a given frame."""
121
+ # Look up in the cache first
122
+ image, backbone_out = inference_state["cached_features"].get(
123
+ frame_idx, (None, None)
124
+ )
125
+ if backbone_out is None:
126
+ # Cache miss -- we will run inference on a single image
127
+ image = inference_state["images"][frame_idx].cuda().float().unsqueeze(0)
128
+ backbone_out = self.forward_image(image)
129
+ # Cache the most recent frame's feature (for repeated interactions with
130
+ # a frame; we can use an LRU cache for more frames in the future).
131
+ inference_state["cached_features"] = {frame_idx: (image, backbone_out)}
132
+
133
+ # expand the features to have the same dimension as the number of objects
134
+ expanded_image = image.expand(batch_size, -1, -1, -1)
135
+ expanded_backbone_out = {
136
+ "backbone_fpn": backbone_out["backbone_fpn"].copy(),
137
+ "vision_pos_enc": backbone_out["vision_pos_enc"].copy(),
138
+ }
139
+ for i, feat in enumerate(expanded_backbone_out["backbone_fpn"]):
140
+ expanded_backbone_out["backbone_fpn"][i] = feat.expand(
141
+ batch_size, -1, -1, -1
142
+ )
143
+ for i, pos in enumerate(expanded_backbone_out["vision_pos_enc"]):
144
+ pos = pos.expand(batch_size, -1, -1, -1)
145
+ expanded_backbone_out["vision_pos_enc"][i] = pos
146
+
147
+ features = self._prepare_backbone_features(expanded_backbone_out)
148
+ features = (expanded_image,) + features
149
+ return features
150
+
151
+
152
+ def _run_single_frame_inference(
153
+ self,
154
+ inference_state,
155
+ output_dict,
156
+ frame_idx,
157
+ batch_size,
158
+ is_init_cond_frame,
159
+ point_inputs,
160
+ mask_inputs,
161
+ reverse,
162
+ run_mem_encoder,
163
+ prev_sam_mask_logits=None,
164
+ ## Extension: LLM prompt
165
+ language_embd=None,
166
+ ):
167
+ """Run tracking on a single frame based on current inputs and previous memory."""
168
+ # Retrieve correct image features
169
+ (
170
+ _,
171
+ _,
172
+ current_vision_feats,
173
+ current_vision_pos_embeds,
174
+ feat_sizes,
175
+ ) = self._get_image_feature(inference_state, frame_idx, batch_size)
176
+
177
+ # point and mask should not appear as input simultaneously on the same frame
178
+ assert point_inputs is None or mask_inputs is None
179
+ current_out = self.track_step(
180
+ frame_idx=frame_idx,
181
+ is_init_cond_frame=is_init_cond_frame,
182
+ current_vision_feats=current_vision_feats,
183
+ current_vision_pos_embeds=current_vision_pos_embeds,
184
+ feat_sizes=feat_sizes,
185
+ point_inputs=point_inputs,
186
+ mask_inputs=mask_inputs,
187
+ output_dict=output_dict,
188
+ num_frames=inference_state["num_frames"],
189
+ track_in_reverse=reverse,
190
+ run_mem_encoder=run_mem_encoder,
191
+ prev_sam_mask_logits=prev_sam_mask_logits,
192
+ language_embd=language_embd,
193
+ )
194
+
195
+ # optionally offload the output to CPU memory to save GPU space
196
+ storage_device = inference_state["storage_device"]
197
+ maskmem_features = current_out["maskmem_features"]
198
+ if maskmem_features is not None:
199
+ maskmem_features = maskmem_features.to(torch.bfloat16)
200
+ maskmem_features = maskmem_features.to(storage_device, non_blocking=True)
201
+ pred_masks_gpu = current_out["pred_masks"]
202
+ # potentially fill holes in the predicted masks
203
+ if self.fill_hole_area > 0:
204
+ pred_masks_gpu = fill_holes_in_mask_scores(
205
+ pred_masks_gpu, self.fill_hole_area
206
+ )
207
+ pred_masks = pred_masks_gpu.to(storage_device, non_blocking=True)
208
+ # "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it
209
+ maskmem_pos_enc = _get_maskmem_pos_enc(inference_state, current_out)
210
+ # object pointer is a small tensor, so we always keep it on GPU memory for fast access
211
+ obj_ptr = current_out["obj_ptr"]
212
+ # make a compact version of this frame's output to reduce the state size
213
+ compact_current_out = {
214
+ "maskmem_features": maskmem_features,
215
+ "maskmem_pos_enc": maskmem_pos_enc,
216
+ "pred_masks": pred_masks,
217
+ "obj_ptr": obj_ptr,
218
+ }
219
+ return compact_current_out, pred_masks_gpu
220
+
221
+
222
+ def _consolidate_temp_output_across_obj(
223
+ self,
224
+ inference_state,
225
+ frame_idx,
226
+ is_cond,
227
+ run_mem_encoder,
228
+ consolidate_at_video_res=False,
229
+ ):
230
+ """
231
+ Consolidate the per-object temporary outputs in `temp_output_dict_per_obj` on
232
+ a frame into a single output for all objects, including
233
+ 1) fill any missing objects either from `output_dict_per_obj` (if they exist in
234
+ `output_dict_per_obj` for this frame) or leave them as placeholder values
235
+ (if they don't exist in `output_dict_per_obj` for this frame);
236
+ 2) if specified, rerun the memory encoder after applying non-overlapping constraints
237
+ on the object scores.
238
+ """
239
+ batch_size = _get_obj_num(inference_state)
240
+ storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"
241
+ # Optionally, we allow consolidating the temporary outputs at the original
242
+ # video resolution (to provide a better editing experience for mask prompts).
243
+ if consolidate_at_video_res:
244
+ assert not run_mem_encoder, "memory encoder cannot run at video resolution"
245
+ consolidated_H = inference_state["video_height"]
246
+ consolidated_W = inference_state["video_width"]
247
+ consolidated_mask_key = "pred_masks_video_res"
248
+ else:
249
+ consolidated_H = consolidated_W = self.image_size // 4
250
+ consolidated_mask_key = "pred_masks"
251
+
252
+ # Initialize `consolidated_out`. Its "maskmem_features" and "maskmem_pos_enc"
253
+ # will be added when rerunning the memory encoder after applying non-overlapping
254
+ # constraints to object scores. Its "pred_masks" are prefilled with a large
255
+ # negative value (NO_OBJ_SCORE) to represent missing objects.
256
+ consolidated_out = {
257
+ "maskmem_features": None,
258
+ "maskmem_pos_enc": None,
259
+ consolidated_mask_key: torch.full(
260
+ size=(batch_size, 1, consolidated_H, consolidated_W),
261
+ fill_value=NO_OBJ_SCORE,
262
+ dtype=torch.float32,
263
+ device=inference_state["storage_device"],
264
+ ),
265
+ "obj_ptr": torch.full(
266
+ size=(batch_size, self.hidden_dim),
267
+ fill_value=NO_OBJ_SCORE,
268
+ dtype=torch.float32,
269
+ device=inference_state["device"],
270
+ ),
271
+ }
272
+ empty_mask_ptr = None
273
+ for obj_idx in range(batch_size):
274
+ obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx]
275
+ obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
276
+ out = obj_temp_output_dict[storage_key].get(frame_idx, None)
277
+ # If the object doesn't appear in "temp_output_dict_per_obj" on this frame,
278
+ # we fall back and look up its previous output in "output_dict_per_obj".
279
+ # We look up both "cond_frame_outputs" and "non_cond_frame_outputs" in
280
+ # "output_dict_per_obj" to find a previous output for this object.
281
+ if out is None:
282
+ out = obj_output_dict["cond_frame_outputs"].get(frame_idx, None)
283
+ if out is None:
284
+ out = obj_output_dict["non_cond_frame_outputs"].get(frame_idx, None)
285
+ # If the object doesn't appear in "output_dict_per_obj" either, we skip it
286
+ # and leave its mask scores to the default scores (i.e. the NO_OBJ_SCORE
287
+ # placeholder above) and set its object pointer to be a dummy pointer.
288
+ if out is None:
289
+ # Fill in dummy object pointers for those objects without any inputs or
290
+ # tracking outcomes on this frame (only do it under `run_mem_encoder=True`,
291
+ # i.e. when we need to build the memory for tracking).
292
+ if run_mem_encoder:
293
+ if empty_mask_ptr is None:
294
+ empty_mask_ptr = self._get_empty_mask_ptr(
295
+ inference_state, frame_idx
296
+ )
297
+ # fill object pointer with a dummy pointer (based on an empty mask)
298
+ consolidated_out["obj_ptr"][obj_idx : obj_idx + 1] = empty_mask_ptr
299
+ continue
300
+ # Add the temporary object output mask to consolidated output mask
301
+ obj_mask = out["pred_masks"]
302
+ consolidated_pred_masks = consolidated_out[consolidated_mask_key]
303
+ if obj_mask.shape[-2:] == consolidated_pred_masks.shape[-2:]:
304
+ consolidated_pred_masks[obj_idx : obj_idx + 1] = obj_mask
305
+ else:
306
+ # Resize first if temporary object mask has a different resolution
307
+ resized_obj_mask = torch.nn.functional.interpolate(
308
+ obj_mask,
309
+ size=consolidated_pred_masks.shape[-2:],
310
+ mode="bilinear",
311
+ align_corners=False,
312
+ )
313
+ consolidated_pred_masks[obj_idx : obj_idx + 1] = resized_obj_mask
314
+ consolidated_out["obj_ptr"][obj_idx : obj_idx + 1] = out["obj_ptr"]
315
+
316
+ # Optionally, apply non-overlapping constraints on the consolidated scores
317
+ # and rerun the memory encoder
318
+ if run_mem_encoder:
319
+ device = inference_state["device"]
320
+ high_res_masks = torch.nn.functional.interpolate(
321
+ consolidated_out["pred_masks"].to(device, non_blocking=True),
322
+ size=(self.image_size, self.image_size),
323
+ mode="bilinear",
324
+ align_corners=False,
325
+ )
326
+ if self.non_overlap_masks_for_mem_enc:
327
+ high_res_masks = self._apply_non_overlapping_constraints(high_res_masks)
328
+ maskmem_features, maskmem_pos_enc = self._run_memory_encoder(
329
+ inference_state=inference_state,
330
+ frame_idx=frame_idx,
331
+ batch_size=batch_size,
332
+ high_res_masks=high_res_masks,
333
+ is_mask_from_pts=True, # these frames are what the user interacted with
334
+ )
335
+ consolidated_out["maskmem_features"] = maskmem_features
336
+ consolidated_out["maskmem_pos_enc"] = maskmem_pos_enc
337
+
338
+ return consolidated_out
339
+
340
+
341
+ def _get_orig_video_res_output(self, inference_state, any_res_masks):
342
+ """
343
+ Resize the object scores to the original video resolution (video_res_masks)
344
+ and apply non-overlapping constraints for final output.
345
+ """
346
+ device = inference_state["device"]
347
+ video_H = inference_state["video_height"]
348
+ video_W = inference_state["video_width"]
349
+ any_res_masks = any_res_masks.to(device, non_blocking=True)
350
+ if any_res_masks.shape[-2:] == (video_H, video_W):
351
+ video_res_masks = any_res_masks
352
+ else:
353
+ video_res_masks = torch.nn.functional.interpolate(
354
+ any_res_masks,
355
+ size=(video_H, video_W),
356
+ mode="bilinear",
357
+ align_corners=False,
358
+ )
359
+ if self.non_overlap_masks:
360
+ video_res_masks = self._apply_non_overlapping_constraints(video_res_masks)
361
+ return any_res_masks, video_res_masks
362
+
363
+ def init_state(
364
+ self,
365
+ images
366
+ ):
367
+ """Initialize a inference state."""
368
+ inference_state = {}
369
+ inference_state["images"] = images
370
+ inference_state["num_frames"] = len(images)
371
+ # whether to offload the video frames to CPU memory
372
+ # turning on this option saves the GPU memory with only a very small overhead
373
+ inference_state["offload_video_to_cpu"] = False
374
+ # whether to offload the inference state to CPU memory
375
+ # turning on this option saves the GPU memory at the cost of a lower tracking fps
376
+ # (e.g. in a test case of 768x768 model, fps dropped from 27 to 24 when tracking one object
377
+ # and from 24 to 21 when tracking two objects)
378
+ inference_state["offload_state_to_cpu"] = False
379
+ # the original video height and width, used for resizing final output scores
380
+ inference_state["video_height"] = self.image_size
381
+ inference_state["video_width"] = self.image_size
382
+ inference_state["device"] = torch.device("cuda")
383
+ inference_state["storage_device"] = torch.device("cuda")
384
+ # inputs on each frame
385
+ inference_state["point_inputs_per_obj"] = {}
386
+ inference_state["mask_inputs_per_obj"] = {}
387
+ # visual features on a small number of recently visited frames for quick interactions
388
+ inference_state["cached_features"] = {}
389
+ # values that don't change across frames (so we only need to hold one copy of them)
390
+ inference_state["constants"] = {}
391
+ # mapping between client-side object id and model-side object index
392
+ inference_state["obj_id_to_idx"] = OrderedDict()
393
+ inference_state["obj_idx_to_id"] = OrderedDict()
394
+ inference_state["obj_ids"] = []
395
+ # A storage to hold the model's tracking results and states on each frame
396
+ inference_state["output_dict"] = {
397
+ "cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
398
+ "non_cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
399
+ }
400
+ # Slice (view) of each object tracking results, sharing the same memory with "output_dict"
401
+ inference_state["output_dict_per_obj"] = {}
402
+ # A temporary storage to hold new outputs when the user interacts with a frame
403
+ # to add clicks or a mask (it's merged into "output_dict" before propagation starts)
404
+ inference_state["temp_output_dict_per_obj"] = {}
405
+ # Frames that already hold consolidated outputs from click or mask inputs
406
+ # (we directly use their consolidated outputs during tracking)
407
+ inference_state["consolidated_frame_inds"] = {
408
+ "cond_frame_outputs": set(), # set containing frame indices
409
+ "non_cond_frame_outputs": set(), # set containing frame indices
410
+ }
411
+ # metadata for each tracking frame (e.g. which direction it's tracked)
412
+ inference_state["tracking_has_started"] = False
413
+ inference_state["frames_already_tracked"] = {}
414
+ return inference_state
415
+
416
+ def add_language_embd(
417
+ self,
418
+ inference_state,
419
+ frame_idx,
420
+ obj_id,
421
+ language_embd,
422
+ inference=False,
423
+ ):
424
+ obj_idx = _obj_id_to_idx(inference_state, obj_id)
425
+
426
+ is_init_cond_frame = frame_idx not in inference_state["frames_already_tracked"]
427
+ # whether to track in reverse time order
428
+ if is_init_cond_frame:
429
+ reverse = False
430
+ else:
431
+ reverse = inference_state["frames_already_tracked"][frame_idx]["reverse"]
432
+
433
+ obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
434
+ obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx]
435
+ # Add a frame to conditioning output if it's an initial conditioning frame or
436
+ # if the model sees all frames receiving clicks/mask as conditioning frames.
437
+ is_cond = is_init_cond_frame or self.add_all_frames_to_correct_as_cond
438
+ storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"
439
+
440
+ # Get any previously predicted mask logits on this object and feed it along with
441
+ # the new clicks into the SAM mask decoder.
442
+ prev_sam_mask_logits = None
443
+ # lookup temporary output dict first, which contains the most recent output
444
+ # (if not found, then lookup conditioning and non-conditioning frame output)
445
+ prev_out = obj_temp_output_dict[storage_key].get(frame_idx)
446
+ if prev_out is None:
447
+ prev_out = obj_output_dict["cond_frame_outputs"].get(frame_idx)
448
+ if prev_out is None:
449
+ prev_out = obj_output_dict["non_cond_frame_outputs"].get(frame_idx)
450
+
451
+ if prev_out is not None and prev_out["pred_masks"] is not None:
452
+ prev_sam_mask_logits = prev_out["pred_masks"].cuda(non_blocking=True)
453
+ # Clamp the scale of prev_sam_mask_logits to avoid rare numerical issues.
454
+ prev_sam_mask_logits = torch.clamp(prev_sam_mask_logits, -32.0, 32.0)
455
+
456
+ current_out, pred_mask_gpu = self._run_single_frame_inference(
457
+ inference_state=inference_state,
458
+ output_dict=obj_output_dict, # run on the slice of a single object
459
+ frame_idx=frame_idx,
460
+ batch_size=1, # run on the slice of a single object
461
+ is_init_cond_frame=is_init_cond_frame,
462
+ point_inputs=None,
463
+ mask_inputs=None,
464
+ reverse=reverse,
465
+ # Skip the memory encoder when adding clicks or a mask. We execute the memory encoder
466
+ # at the beginning of `propagate_in_video` (after the user finalizes their clicks). This
467
+ # allows us to enforce non-overlapping constraints on all objects before encoding
468
+ # them into memory.
469
+ run_mem_encoder=False,
470
+ prev_sam_mask_logits=prev_sam_mask_logits,
471
+ ## Extension: LLM prompt
472
+ language_embd=language_embd,
473
+ )
474
+ # Add the output to the output dict (to be used as future memory)
475
+ obj_temp_output_dict[storage_key][frame_idx] = current_out
476
+
477
+ # Resize the output mask to the original video resolution
478
+ obj_ids = inference_state["obj_ids"]
479
+ if inference:
480
+ _consolidated_out = self._consolidate_temp_output_across_obj(
481
+ inference_state,
482
+ frame_idx,
483
+ is_cond=is_cond,
484
+ run_mem_encoder=False,
485
+ consolidate_at_video_res=False,
486
+ )
487
+ # _, video_res_masks = self._get_orig_video_res_output(
488
+ # inference_state, consolidated_out["pred_masks_video_res"]
489
+ # )
490
+ return frame_idx, obj_ids, pred_mask_gpu
491
+
492
+
493
+ def _clear_non_cond_mem_around_input(self, inference_state, frame_idx):
494
+ """
495
+ Remove the non-conditioning memory around the input frame. When users provide
496
+ correction clicks, the surrounding frames' non-conditioning memories can still
497
+ contain outdated object appearance information and could confuse the model.
498
+
499
+ This method clears those non-conditioning memories surrounding the interacted
500
+ frame to avoid giving the model both old and new information about the object.
501
+ """
502
+ r = self.memory_temporal_stride_for_eval
503
+ frame_idx_begin = frame_idx - r * self.num_maskmem
504
+ frame_idx_end = frame_idx + r * self.num_maskmem
505
+ output_dict = inference_state["output_dict"]
506
+ non_cond_frame_outputs = output_dict["non_cond_frame_outputs"]
507
+ for t in range(frame_idx_begin, frame_idx_end + 1):
508
+ non_cond_frame_outputs.pop(t, None)
509
+ for obj_output_dict in inference_state["output_dict_per_obj"].values():
510
+ obj_output_dict["non_cond_frame_outputs"].pop(t, None)
511
+
512
+ def _run_memory_encoder(
513
+ self, inference_state, frame_idx, batch_size, high_res_masks, is_mask_from_pts
514
+ ):
515
+ """
516
+ Run the memory encoder on `high_res_masks`. This is usually after applying
517
+ non-overlapping constraints to object scores. Since their scores changed, their
518
+ memory also needs to be computed again with the memory encoder.
519
+ """
520
+ # Retrieve correct image features
521
+ _, _, current_vision_feats, _, feat_sizes = self._get_image_feature(
522
+ inference_state, frame_idx, batch_size
523
+ )
524
+ maskmem_features, maskmem_pos_enc = self._encode_new_memory(
525
+ current_vision_feats=current_vision_feats,
526
+ feat_sizes=feat_sizes,
527
+ pred_masks_high_res=high_res_masks,
528
+ is_mask_from_pts=is_mask_from_pts,
529
+ )
530
+
531
+ # optionally offload the output to CPU memory to save GPU space
532
+ storage_device = inference_state["storage_device"]
533
+ maskmem_features = maskmem_features.to(torch.bfloat16)
534
+ maskmem_features = maskmem_features.to(storage_device, non_blocking=True)
535
+ # "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it
536
+ maskmem_pos_enc = _get_maskmem_pos_enc(
537
+ inference_state, {"maskmem_pos_enc": maskmem_pos_enc}
538
+ )
539
+ return maskmem_features, maskmem_pos_enc
540
+
541
+ def _add_output_per_object(
542
+ self, inference_state, frame_idx, current_out, storage_key
543
+ ):
544
+ """
545
+ Split a multi-object output into per-object output slices and add them into
546
+ `output_dict_per_obj`. The resulting slices share the same tensor storage.
547
+ """
548
+ maskmem_features = current_out["maskmem_features"]
549
+ assert maskmem_features is None or isinstance(maskmem_features, torch.Tensor)
550
+
551
+ maskmem_pos_enc = current_out["maskmem_pos_enc"]
552
+ assert maskmem_pos_enc is None or isinstance(maskmem_pos_enc, list)
553
+
554
+ output_dict_per_obj = inference_state["output_dict_per_obj"]
555
+ for obj_idx, obj_output_dict in output_dict_per_obj.items():
556
+ obj_slice = slice(obj_idx, obj_idx + 1)
557
+ obj_out = {
558
+ "maskmem_features": None,
559
+ "maskmem_pos_enc": None,
560
+ "pred_masks": current_out["pred_masks"][obj_slice],
561
+ "obj_ptr": current_out["obj_ptr"][obj_slice],
562
+ }
563
+ if maskmem_features is not None:
564
+ obj_out["maskmem_features"] = maskmem_features[obj_slice]
565
+ if maskmem_pos_enc is not None:
566
+ obj_out["maskmem_pos_enc"] = [x[obj_slice] for x in maskmem_pos_enc]
567
+ obj_output_dict[storage_key][frame_idx] = obj_out
568
+
569
+ @torch.inference_mode()
570
+ def propagate_in_video_preflight(self, inference_state):
571
+ """Prepare inference_state and consolidate temporary outputs before tracking."""
572
+ # Tracking has started and we don't allow adding new objects until session is reset.
573
+ inference_state["tracking_has_started"] = True
574
+ batch_size = _get_obj_num(inference_state)
575
+
576
+ # Consolidate per-object temporary outputs in "temp_output_dict_per_obj" and
577
+ # add them into "output_dict".
578
+ temp_output_dict_per_obj = inference_state["temp_output_dict_per_obj"]
579
+ output_dict = inference_state["output_dict"]
580
+ # "consolidated_frame_inds" contains indices of those frames where consolidated
581
+ # temporary outputs have been added (either in this call or any previous calls
582
+ # to `propagate_in_video_preflight`).
583
+ consolidated_frame_inds = inference_state["consolidated_frame_inds"]
584
+ for is_cond in [False, True]:
585
+ # Separately consolidate conditioning and non-conditioning temp outputs
586
+ storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"
587
+ # Find all the frames that contain temporary outputs for any objects
588
+ # (these should be the frames that have just received clicks for mask inputs
589
+ # via `add_new_points` or `add_new_mask`)
590
+ temp_frame_inds = set()
591
+ for obj_temp_output_dict in temp_output_dict_per_obj.values():
592
+ temp_frame_inds.update(obj_temp_output_dict[storage_key].keys())
593
+ consolidated_frame_inds[storage_key].update(temp_frame_inds)
594
+ # consolidate the temporary output across all objects on this frame
595
+ for frame_idx in temp_frame_inds:
596
+ consolidated_out = self._consolidate_temp_output_across_obj(
597
+ inference_state, frame_idx, is_cond=is_cond, run_mem_encoder=True
598
+ )
599
+ # merge them into "output_dict" and also create per-object slices
600
+ output_dict[storage_key][frame_idx] = consolidated_out
601
+ self._add_output_per_object(
602
+ inference_state, frame_idx, consolidated_out, storage_key
603
+ )
604
+ clear_non_cond_mem = self.clear_non_cond_mem_around_input and (
605
+ self.clear_non_cond_mem_for_multi_obj or batch_size <= 1
606
+ )
607
+ if clear_non_cond_mem:
608
+ # clear non-conditioning memory of the surrounding frames
609
+ self._clear_non_cond_mem_around_input(inference_state, frame_idx)
610
+
611
+ # clear temporary outputs in `temp_output_dict_per_obj`
612
+ for obj_temp_output_dict in temp_output_dict_per_obj.values():
613
+ obj_temp_output_dict[storage_key].clear()
614
+
615
+ # edge case: if an output is added to "cond_frame_outputs", we remove any prior
616
+ # output on the same frame in "non_cond_frame_outputs"
617
+ for frame_idx in output_dict["cond_frame_outputs"]:
618
+ output_dict["non_cond_frame_outputs"].pop(frame_idx, None)
619
+ for obj_output_dict in inference_state["output_dict_per_obj"].values():
620
+ for frame_idx in obj_output_dict["cond_frame_outputs"]:
621
+ obj_output_dict["non_cond_frame_outputs"].pop(frame_idx, None)
622
+ for frame_idx in consolidated_frame_inds["cond_frame_outputs"]:
623
+ assert frame_idx in output_dict["cond_frame_outputs"]
624
+ consolidated_frame_inds["non_cond_frame_outputs"].discard(frame_idx)
625
+
626
+ # Make sure that the frame indices in "consolidated_frame_inds" are exactly those frames
627
+ # with either points or mask inputs (which should be true under a correct workflow).
628
+ all_consolidated_frame_inds = (
629
+ consolidated_frame_inds["cond_frame_outputs"]
630
+ | consolidated_frame_inds["non_cond_frame_outputs"]
631
+ )
632
+ input_frames_inds = set()
633
+ for point_inputs_per_frame in inference_state["point_inputs_per_obj"].values():
634
+ input_frames_inds.update(point_inputs_per_frame.keys())
635
+ for mask_inputs_per_frame in inference_state["mask_inputs_per_obj"].values():
636
+ input_frames_inds.update(mask_inputs_per_frame.keys())
637
+
638
+ # with language embd as input, there may not be point or box
639
+ # assert all_consolidated_frame_inds == input_frames_inds
640
+
641
+ @torch.inference_mode()
642
+ def propagate_in_video(
643
+ self,
644
+ inference_state,
645
+ start_frame_idx=None,
646
+ max_frame_num_to_track=None,
647
+ reverse=False,
648
+ ):
649
+ """Propagate the input points across frames to track in the entire video."""
650
+ self.propagate_in_video_preflight(inference_state)
651
+
652
+ output_dict = inference_state["output_dict"]
653
+ consolidated_frame_inds = inference_state["consolidated_frame_inds"]
654
+ obj_ids = inference_state["obj_ids"]
655
+ num_frames = inference_state["num_frames"]
656
+ batch_size = _get_obj_num(inference_state)
657
+ if len(output_dict["cond_frame_outputs"]) == 0:
658
+ raise RuntimeError("No points are provided; please add points first")
659
+ clear_non_cond_mem = self.clear_non_cond_mem_around_input and (
660
+ self.clear_non_cond_mem_for_multi_obj or batch_size <= 1
661
+ )
662
+
663
+ # set start index, end index, and processing order
664
+ if start_frame_idx is None:
665
+ # default: start from the earliest frame with input points
666
+ start_frame_idx = min(output_dict["cond_frame_outputs"])
667
+ if max_frame_num_to_track is None:
668
+ # default: track all the frames in the video
669
+ max_frame_num_to_track = num_frames
670
+ if reverse:
671
+ end_frame_idx = max(start_frame_idx - max_frame_num_to_track, 0)
672
+ if start_frame_idx > 0:
673
+ processing_order = range(start_frame_idx, end_frame_idx - 1, -1)
674
+ else:
675
+ processing_order = [] # skip reverse tracking if starting from frame 0
676
+ else:
677
+ end_frame_idx = min(
678
+ start_frame_idx + max_frame_num_to_track, num_frames - 1
679
+ )
680
+ processing_order = range(start_frame_idx, end_frame_idx + 1)
681
+
682
+ for frame_idx in tqdm(processing_order, desc="propagate in video"):
683
+ # We skip those frames already in consolidated outputs (these are frames
684
+ # that received input clicks or mask). Note that we cannot directly run
685
+ # batched forward on them via `_run_single_frame_inference` because the
686
+ # number of clicks on each object might be different.
687
+ if frame_idx in consolidated_frame_inds["cond_frame_outputs"]:
688
+ storage_key = "cond_frame_outputs"
689
+ current_out = output_dict[storage_key][frame_idx]
690
+ pred_masks = current_out["pred_masks"]
691
+ if clear_non_cond_mem:
692
+ # clear non-conditioning memory of the surrounding frames
693
+ self._clear_non_cond_mem_around_input(inference_state, frame_idx)
694
+ elif frame_idx in consolidated_frame_inds["non_cond_frame_outputs"]:
695
+ storage_key = "non_cond_frame_outputs"
696
+ current_out = output_dict[storage_key][frame_idx]
697
+ pred_masks = current_out["pred_masks"]
698
+ else:
699
+ storage_key = "non_cond_frame_outputs"
700
+ current_out, pred_masks = self._run_single_frame_inference(
701
+ inference_state=inference_state,
702
+ output_dict=output_dict,
703
+ frame_idx=frame_idx,
704
+ batch_size=batch_size,
705
+ is_init_cond_frame=False,
706
+ point_inputs=None,
707
+ mask_inputs=None,
708
+ reverse=reverse,
709
+ run_mem_encoder=True,
710
+ )
711
+ output_dict[storage_key][frame_idx] = current_out
712
+ # Create slices of per-object outputs for subsequent interaction with each
713
+ # individual object after tracking.
714
+ self._add_output_per_object(
715
+ inference_state, frame_idx, current_out, storage_key
716
+ )
717
+ inference_state["frames_already_tracked"][frame_idx] = {"reverse": reverse}
718
+
719
+ # Resize the output mask to the original video resolution (we directly use
720
+ # the mask scores on GPU for output to avoid any CPU conversion in between)
721
+ _, video_res_masks = self._get_orig_video_res_output(
722
+ inference_state, pred_masks
723
+ )
724
+ yield frame_idx, obj_ids, video_res_masks
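
A rough end-to-end sketch of how the predictor above is driven (an aside, not part of the diff; `build_predictor`, `frames` and `text_embed` are hypothetical placeholders for the model construction, the stacked video frames and the LLM prompt embedding):

import torch

predictor = build_predictor()                    # hypothetical: returns a SAM2VideoPredictor
state = predictor.init_state(images=frames)      # frames: (T, C, H, W) float tensor

# Prompt object 0 on the first frame with a language embedding instead of clicks or masks.
_, obj_ids, first_mask = predictor.add_language_embd(
    state, frame_idx=0, obj_id=0, language_embd=text_embed, inference=True
)

# Propagate the prompt through the remaining frames; masks are returned at the
# resolution recorded in the inference state.
for frame_idx, obj_ids, masks in predictor.propagate_in_video(state):
    binary_masks = masks > 0.0                   # (num_objects, 1, H, W) boolean masks
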
rynnec/model/processor.py ADDED
@@ -0,0 +1,401 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
5
+ # and OPT implementations in this library. It has been modified from its
6
+ # original forms to accommodate minor architectural differences compared
7
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+ """
21
+ Processor class for VideoLLaMA3.
22
+ """
23
+ from abc import ABCMeta, abstractmethod
24
+ import copy
25
+ import warnings
26
+ from collections import defaultdict
27
+ from typing import List, Union, Dict, Optional, Any
28
+
29
+ import json
30
+ import torch
31
+ from transformers.feature_extraction_utils import BatchFeature
32
+ from transformers.image_utils import ImageInput, VideoInput
33
+ from transformers.processing_utils import ProcessingKwargs, ProcessorMixin, Unpack
34
+ from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
35
+
36
+ from rynnec.constants import DEFAULT_IMAGE_TOKEN, IGNORE_INDEX
37
+ from rynnec.mm_utils import load_video, load_images
38
+ from rynnec.model.videollama3_encoder.image_processing_videollama3 import is_valid_image, is_valid_video
39
+
40
+
41
+ class Videollama3ProcessorKwargs(ProcessingKwargs, total=False):
42
+ _defaults = {
43
+ "text_kwargs": {
44
+ "padding": False,
45
+ },
46
+ }
47
+
48
+
49
+ class Videollama3BaseProcessor(ProcessorMixin, metaclass=ABCMeta):
50
+ r"""
51
+ Modified from Qwen2VLProcessor
52
+ Args:
53
+ image_processor ([`Qwen2VLImageProcessor`], *optional*):
54
+ The image processor is a required input.
55
+ tokenizer ([`Qwen2TokenizerFast`], *optional*):
56
+ The tokenizer is a required input.
57
+ chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
58
+ in a chat into a tokenizable string.
59
+ """
60
+
61
+ attributes = ["image_processor", "tokenizer"]
62
+ valid_kwargs = ["chat_template", "image_merge_size", "video_merge_size", "fps", "max_frames"]
63
+ image_processor_class = "AutoImageProcessor"
64
+ tokenizer_class = None
65
+ chat_template = None
66
+
67
+ def __init__(
68
+ self,
69
+ image_processor=None,
70
+ tokenizer=None,
71
+ chat_template=None,
72
+ image_merge_size: int = 1,
73
+ video_merge_size: int = 2,
74
+ fps=1,
75
+ max_frames=180,
76
+ **kwargs
77
+ ):
78
+ if chat_template is not None:
79
+ self.chat_template = chat_template
80
+
81
+ self.image_processor = image_processor
82
+ self.tokenizer = tokenizer
83
+ self.image_merge_size = image_merge_size
84
+ self.video_merge_size = video_merge_size
85
+ self.fps = fps
86
+ self.max_frames = max_frames
87
+
88
+ if self.chat_template is not None:
89
+ self.tokenizer.chat_template = self.chat_template
90
+
91
+ self.image_token = DEFAULT_IMAGE_TOKEN
92
+ self.think_start_token = "<think>"
93
+ self.think_end_token = "</think>"
94
+ self.tokenizer.add_tokens([self.image_token], special_tokens=True)
95
+ self.tokenizer.add_tokens([self.think_start_token, self.think_end_token], special_tokens=False)
96
+ self.image_token_id = self.tokenizer.convert_tokens_to_ids(self.image_token)
97
+ self.think_start_token_id = self.tokenizer.convert_tokens_to_ids(self.think_start_token)
98
+ self.think_end_token_id = self.tokenizer.convert_tokens_to_ids(self.think_end_token)
99
+ self.newline_token_id = self.tokenizer.encode("\n")[0]
100
+
101
+ def load_video(self, *args, **kwargs):
102
+ return load_video(*args, **kwargs)
103
+
104
+ def load_images(self, *args, **kwargs):
105
+ return load_images(*args, **kwargs)
106
+
107
+ def _get_downsampled_grid_sizes(self, image_inputs: Dict[str, Any]):
108
+ grid_sizes = []
109
+ for grid_size, merge_size in zip(image_inputs.get("grid_sizes", []), image_inputs.get("merge_sizes", [])):
110
+ if not torch.all(grid_size[1:] % merge_size == 0):
111
+ warnings.warn(f"Grid size {grid_size} is not divisible by merge size. Some undesired errors may occur.")
112
+ if grid_size[0] == 1:
113
+ grid_sizes.append(grid_size[1:] / merge_size)
114
+ elif grid_size[0] > 1:
115
+ grid_sizes.extend([grid_size[1:] / merge_size] * grid_size[0])
116
+ return grid_sizes
117
+
118
+ def _get_visual_seq_len(self, grid_size: torch.Tensor):
119
+ num_tokens = int(grid_size.prod().item())
120
+ return num_tokens
121
+
122
+ @abstractmethod
123
+ def _process_text_with_label(
124
+ self,
125
+ text: List[Dict],
126
+ grid_sizes: torch.Tensor = None,
127
+ **kwargs,
128
+ ):
129
+ return {}
130
+
131
+ def _process_text_without_label(
132
+ self,
133
+ text: Union[List[str], List[Dict]],
134
+ grid_sizes: torch.Tensor = None,
135
+ **kwargs,
136
+ ):
137
+ if isinstance(text, (list, tuple)) and isinstance(text[0], dict):
138
+ warnings.warn("Input text is a list of messages. Automatically convert it to a string with 'apply_chat_template' with generation prompt.")
139
+ text = self.apply_chat_template(text, tokenize=False, add_generation_prompt=True)
140
+
141
+ if len(grid_sizes) > 0:
142
+ image_idx = 0
143
+ while self.image_token in text:
144
+ thw = grid_sizes[image_idx]
145
+ text = text.replace(self.image_token, "<placeholder>" * thw.prod().long(), 1)
146
+ image_idx += 1
147
+ text = text.replace("<placeholder>", self.image_token)
148
+ assert len(grid_sizes) == image_idx, "Number of images does not match the number of image tokens in the text."
149
+
150
+ text_inputs = self.tokenizer(text, **kwargs)
151
+ return text_inputs
152
+
153
+ def process_text(
154
+ self,
155
+ text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput], List[Dict]],
156
+ image_inputs: Dict[str, torch.Tensor] = {},
157
+ return_labels: bool = False,
158
+ **kwargs,
159
+ ):
160
+ kwargs.pop("padding", None)
161
+ kwargs.pop("padding_side", None)
162
+
163
+ grid_sizes = []
164
+ for grid_size, merge_size in zip(image_inputs.get("grid_sizes", []), image_inputs.get("merge_sizes", [])):
165
+ if not torch.all(grid_size[1:] % merge_size == 0):
166
+ warnings.warn(f"Grid size {grid_size} is not divisible by merge size. Some undesired errors may occur.")
167
+ if grid_size[0] == 1:
168
+ grid_sizes.append(grid_size[1:] / merge_size)
169
+ elif grid_size[0] > 1:
170
+ grid_sizes.extend([grid_size[1:] / merge_size] * grid_size[0])
171
+
172
+ if return_labels:
173
+ return self._process_text_with_label(text, grid_sizes, **kwargs)
174
+ return self._process_text_without_label(text, grid_sizes, **kwargs)
175
+
176
+ def process_images(
177
+ self,
178
+ images: ImageInput = None,
179
+ merge_size: Optional[int] = 1,
180
+ **kwargs,
181
+ ):
182
+ if images is None:
183
+ return {}
184
+ image_inputs = self.image_processor(images=images, merge_size=merge_size, **kwargs)
185
+ return image_inputs
186
+
187
+ def __call__(
188
+ self,
189
+ text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput], List[Dict]] = None,
190
+ images: ImageInput = None,
191
+ merge_size: Optional[int] = 1,
192
+ return_labels: bool = False,
193
+ **kwargs: Unpack[Videollama3ProcessorKwargs],
194
+ ) -> BatchFeature:
195
+ """
196
+ Main method to prepare one or several sequence(s) and image(s) for the model. This method forwards the `text`
197
+ and `kwargs` arguments to Qwen2TokenizerFast's [`~Qwen2TokenizerFast.__call__`] if `text` is not `None` to encode
198
+ the text. To prepare the vision inputs, this method forwards the `vision_infos` and `kwargs` arguments to
199
+ Qwen2VLImageProcessor's [`~Qwen2VLImageProcessor.__call__`] if `vision_infos` is not `None`.
200
+
201
+ Args:
202
+ images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
203
+ The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
204
+ tensor. Both channels-first and channels-last formats are supported.
205
+ text (`str`, `List[str]`, `List[List[str]]`):
206
+ The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
207
+ (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
208
+ `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
209
+ return_tensors (`str` or [`~utils.TensorType`], *optional*):
210
+ If set, will return tensors of a particular framework. Acceptable values are:
211
+ - `'tf'`: Return TensorFlow `tf.constant` objects.
212
+ - `'pt'`: Return PyTorch `torch.Tensor` objects.
213
+ - `'np'`: Return NumPy `np.ndarray` objects.
214
+ - `'jax'`: Return JAX `jnp.ndarray` objects.
215
+
216
+ Returns:
217
+ [`BatchFeature`]: A [`BatchFeature`] with the following fields:
218
+
219
+ - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
220
+ - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
221
+ `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
222
+ `None`).
223
+ - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
224
+ - **grid_sizes** -- List of 3D image grid sizes in the LLM. Returned when `images` is not `None`.
225
+ """
226
+ output_kwargs = self._merge_kwargs(
227
+ Videollama3ProcessorKwargs,
228
+ tokenizer_init_kwargs=self.tokenizer.init_kwargs,
229
+ **kwargs,
230
+ )
231
+ output_kwargs["text_kwargs"].pop("padding", None)
232
+ output_kwargs["text_kwargs"].pop("padding_side", None)
233
+
234
+ image_inputs = self.process_images(images, merge_size, **output_kwargs["images_kwargs"])
235
+ text_inputs = self.process_text(text, image_inputs, return_labels, **output_kwargs["text_kwargs"])
236
+
237
+ return BatchFeature(data={**text_inputs, **image_inputs})
238
+
239
+ def batch_decode(self, *args, **kwargs):
240
+ """
241
+ This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
242
+ refer to the docstring of this method for more information.
243
+ """
244
+ return self.tokenizer.batch_decode(*args, **kwargs)
245
+
246
+ def decode(self, *args, **kwargs):
247
+ """
248
+ This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
249
+ the docstring of this method for more information.
250
+ """
251
+ return self.tokenizer.decode(*args, **kwargs)
252
+
253
+ def _load_multimodal_data(self, conversation: List[Dict[str, Any]]):
254
+ multimodal_info = defaultdict(list)
255
+ new_conversation = []
256
+ for message in conversation:
257
+ new_message = {"role": message["role"]}
258
+ if not isinstance(message["content"], (list, tuple)):
259
+ new_message["content"] = message["content"]
260
+ new_conversation.append(new_message)
261
+ continue
262
+
263
+ new_contents = []
264
+ for content in message["content"]:
265
+ if not isinstance(content, dict):
266
+ new_contents.append(content)
267
+ continue
268
+ assert "type" in content, "Content must have 'type' field."
269
+ if content["type"] in ["image", "video"] and content["type"] in content and isinstance(content[content["type"]], dict):
270
+ # TODO: support other types which are not compatible with json
271
+ load_args = content[content["type"]]
272
+ data_id = json.dumps({k: v for k, v in load_args.items() if k not in ["start_time", "end_time"]})
273
+ new_content = copy.deepcopy(content)
274
+ multimodal_info[data_id].append(new_content)
275
+ new_contents.append(new_content)
276
+ else:
277
+ new_contents.append(content)
278
+
279
+ new_message["content"] = new_contents
280
+ new_conversation.append(new_message)
281
+
282
+ for data_id, contents in multimodal_info.items():
283
+ data_type = contents[0]["type"]
284
+ if data_type == "image":
285
+ image = self.load_images(contents[0][data_type]["image_path"])[0]
286
+ for content in contents:
287
+ content["image"] = image.copy()
288
+
289
+ elif data_type == "video":
290
+ # TODO: start_time is None?
291
+ start_times = [content["video"].get("start_time", 0.) for content in contents]
292
+ end_times = [content["video"].get("end_time", float("inf")) for content in contents]
293
+
294
+ load_args = contents[0][data_type]
295
+ start_time, end_time = min(start_times), max(end_times)
296
+ if start_time > 0:
297
+ load_args["start_time"] = start_time
298
+ if end_time < float("inf"):
299
+ load_args["end_time"] = end_time
300
+ images, timestamps = self.load_video(**load_args)
301
+
302
+ for content, start_time, end_time in zip(contents, start_times, end_times):
303
+ cur_images, cur_timestamps = [], []
304
+ for image, timestamp in zip(images, timestamps):
305
+ if start_time <= timestamp <= end_time:
306
+ cur_images.append(image.copy())
307
+ cur_timestamps.append(timestamp)
308
+
309
+ content[data_type] = cur_images
310
+ content["num_frames"] = len(cur_images)
311
+ content["timestamps"] = cur_timestamps
312
+
313
+ return new_conversation
314
+
315
+ def _gather_multimodal_data(self, conversation: List[Dict[str, Any]]):
316
+ images = []
317
+ for message in conversation:
318
+ if not isinstance(message["content"], (list, tuple)):
319
+ continue
320
+ for content in message["content"]:
321
+ if not isinstance(content, dict):
322
+ continue
323
+ if content["type"] == "video":
324
+ video = content["video"]
325
+ assert is_valid_video(video), f"Invalid video data: {video}."
326
+ images.append(video)
327
+ if content["type"] == "image":
328
+ image = content["image"]
329
+ assert is_valid_image(image), f"Invalid image data: {image}."
330
+ images.append(image)
331
+ images = images if len(images) > 0 else None
332
+ return images
333
+
334
+ def apply_chat_template(
335
+ self,
336
+ conversation: List[Dict[str, Any]],
337
+ chat_template: Optional[str] = None,
338
+ tokenize: bool = False,
339
+ add_system_prompt: bool = False,
340
+ add_generation_prompt: bool = False,
341
+ add_think_prompt: bool = False,
342
+ return_dict: bool = False,
343
+ **kwargs,
344
+ ) -> str:
345
+ """
346
+ Similar to the `apply_chat_template` method on tokenizers, this method applies a Jinja template to input
347
+ conversations to turn them into a single tokenizable string.
348
+ Args:
349
+ conversation (`List[Dict, str, str]`):
350
+ The conversation to format.
351
+ chat_template (`Optional[str]`, *optional*):
352
+ The Jinja template to use for formatting the conversation. If not provided, the processor's
353
+ chat template is used.
354
+ tokenize (`bool`, *optional*, defaults to `False`):
355
+ Whether to tokenize the output or not.
356
+ add_system_prompt (`bool`, *optional*, defaults to `False`):
357
+ Whether to add the system prompt to the output or not.
358
+ add_generation_prompt (`bool`, *optional*, defaults to `False`):
359
+ Whether to add the generation prompt to the output or not.
360
+ image_token (`Optional[str]`, *optional*, defaults to `<image>`):
361
+ The token to use for indicating images in the conversation.
362
+ **kwargs:
363
+ Additional keyword arguments
364
+ """
365
+
366
+ if chat_template is None:
367
+ if self.chat_template is not None:
368
+ chat_template = self.chat_template
369
+ else:
370
+ raise ValueError(
371
+ "No chat template is set for this processor. Please either set the `chat_template` attribute, "
372
+ "or provide a chat template as an argument. See "
373
+ "https://huggingface.co/docs/transformers/main/en/chat_templating for more information."
374
+ )
375
+
376
+ images = None
377
+ if return_dict:
378
+ conversation = self._load_multimodal_data(conversation)
379
+ images = self._gather_multimodal_data(conversation)
380
+
381
+ prompt = self.tokenizer.apply_chat_template(
382
+ conversation,
383
+ chat_template=chat_template,
384
+ tokenize=tokenize,
385
+ add_system_prompt=add_system_prompt,
386
+ add_generation_prompt=add_generation_prompt,
387
+ add_think_prompt=add_think_prompt,
388
+ image_token=self.image_token,
389
+ **kwargs
390
+ )
391
+
392
+ out = {"text": prompt, "images": images}
393
+ if return_dict:
394
+ return out
395
+ return out["text"]
396
+
397
+ @property
398
+ def model_input_names(self):
399
+ tokenizer_input_names = self.tokenizer.model_input_names
400
+ image_processor_input_names = self.image_processor.model_input_names
401
+ return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
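For reference, the sketch below shows one way to drive `apply_chat_template` with `return_dict=True`, so the processor both loads the frames and renders the prompt. It assumes an already-constructed `processor` instance of the class above; the video file name and the exact content keys (`video_path`, `fps`, `max_frames`) follow the VideoLLaMA3-style loader and are assumptions, not guarantees.

# Minimal usage sketch (assumed keys and file name, see note above).
conversation = [
    {"role": "system", "content": "You are a helpful assistant."},
    {
        "role": "user",
        "content": [
            {"type": "video", "video": {"video_path": "demo.mp4", "fps": 1, "max_frames": 64}},
            {"type": "text", "text": "Describe the object highlighted in the clip."},
        ],
    },
]

out = processor.apply_chat_template(
    conversation,
    add_generation_prompt=True,
    return_dict=True,   # also loads and returns the gathered frames
)
prompt, images = out["text"], out["images"]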
rynnec/model/projector.py ADDED
@@ -0,0 +1,161 @@
1
+ # Copyright 2024 Alibaba DAMO Academy
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import math
16
+ import os
17
+ import re
18
+
19
+ import einops
20
+ import torch
21
+ import torch.nn as nn
22
+ import torch.nn.functional as F
23
+ from timm.models.layers import LayerNorm, LayerNorm2d
24
+ from timm.models.regnet import RegStage
25
+ from transformers import TRANSFORMERS_CACHE
26
+
27
+
28
+ def parse_snapshot_folder(repo_id, cache_dir=None, repo_type="model"):
29
+ revision = "main"
30
+ # 1. parse the downloaded cache folder
31
+ if cache_dir is None:
32
+ cache_dir = TRANSFORMERS_CACHE
33
+ else:
34
+ cache_dir = cache_dir
35
+ object_id = repo_id.replace("/", "--")
36
+ repo_cache = os.path.join(cache_dir, f"{repo_type}s--{object_id}")
37
+ # 2. resolve refs (for instance to convert main to the associated commit sha)
38
+ refs_dir = os.path.join(repo_cache, "refs")
39
+ if os.path.isdir(refs_dir):
40
+ revision_file = os.path.join(refs_dir, revision)
41
+ if os.path.isfile(revision_file):
42
+ with open(revision_file) as f:
43
+ revision = f.read()
44
+ # 3. acquire the snapshot folder
45
+ folder = os.path.join(repo_cache, "snapshots", revision)
46
+
47
+ return folder
48
+
49
+
50
+ def load_mm_projector(model_path, cache_dir=None, token=None):
51
+ if os.path.exists(os.path.join(model_path, 'mm_projector.bin')):
52
+ is_local = True
53
+ folder = model_path
54
+ else:
55
+ is_local = False
56
+ folder = parse_snapshot_folder(model_path, cache_dir=cache_dir, repo_type="model")
57
+ if not os.path.exists(os.path.join(folder, 'mm_projector.bin')):
58
+ # downloading from remote repo
59
+ from huggingface_hub import snapshot_download
60
+ snapshot_download(repo_id=model_path, cache_dir=cache_dir, token=token)
61
+
62
+ mm_projector_weights = torch.load(os.path.join(folder, 'mm_projector.bin'), map_location='cpu')
63
+ mm_projector_weights = {k: v.to(torch.float16) for k, v in mm_projector_weights.items()}
64
+ return mm_projector_weights
65
+
66
+
67
+ class IdentityMap(nn.Module):
68
+
69
+ def __init__(self):
70
+ super().__init__()
71
+
72
+ def forward(self, x, *args, **kwargs):
73
+ return x
74
+
75
+ @property
76
+ def config(self):
77
+ return {"mm_projector_type": 'identity'}
78
+
79
+
80
+ def build_mlp(depth, hidden_size, output_hidden_size):
81
+ modules = [nn.Linear(hidden_size, output_hidden_size)]
82
+ for _ in range(1, depth):
83
+ modules.append(nn.GELU())
84
+ modules.append(nn.Linear(output_hidden_size, output_hidden_size))
85
+ return nn.Sequential(*modules)
86
+
87
+
88
+ class SimSpatialConv(nn.Module):
89
+
90
+ def __init__(self, mm_hidden_size, hidden_size, downsample=(2, 2), padding=1, depth=1, mlp_depth=2):
91
+ super().__init__()
92
+ self.encoder_hidden_size = encoder_hidden_size = mm_hidden_size
93
+ self.output_hidden_size = output_hidden_size = hidden_size
94
+ self.downsample = downsample
95
+ self.padding = padding
96
+ self.sampler = nn.Sequential(
97
+ nn.Conv2d(
98
+ in_channels=self.encoder_hidden_size,
99
+ out_channels=4 * self.encoder_hidden_size,
100
+ kernel_size=self.downsample,
101
+ stride=self.downsample,
102
+ padding=self.padding,
103
+ bias=True
104
+ ),
105
+ nn.SiLU(),
106
+ )
107
+ self.readout = build_mlp(mlp_depth, 4 * self.encoder_hidden_size, self.output_hidden_size)
108
+
109
+ def forward(self, x):
110
+ hw = int(x.size(1) ** 0.5)
111
+ x = einops.rearrange(x, "b (h w) d -> b d h w", h=hw, w=hw)
112
+ x = self.sampler(x)
113
+ x = einops.rearrange(x, "b d h w -> b (h w) d")
114
+ x = self.readout(x)
115
+ return x
116
+
117
+ def cal_proj_size(self, input_size):
118
+ if isinstance(input_size, int):
119
+ input_size = (input_size, input_size)
120
+ height = math.ceil((input_size[0] + self.padding) / self.downsample[0])
121
+ width = math.ceil((input_size[1] + self.padding) / self.downsample[1])
122
+ return height * width
123
+
124
+
125
+ class MlpGeluProjector(nn.Module):
126
+ def __init__(self, mm_hidden_size, hidden_size, projector_type):
127
+ super().__init__()
128
+
129
+ mlp_gelu_match = re.match(r"^mlp(\d+)x_gelu$", projector_type)
130
+ mlp_depth = int(mlp_gelu_match.group(1))
131
+
132
+ self.readout = build_mlp(mlp_depth, mm_hidden_size, hidden_size)
133
+
134
+ def forward(self, x):
135
+ x = self.readout(x)
136
+ return x
137
+
138
+ def cal_proj_size(self, input_size):
139
+ if isinstance(input_size, int):
140
+ input_size = (input_size, input_size)
141
+ height = input_size[0]
142
+ width = input_size[1]
143
+ return height * width
144
+
145
+
146
+ def build_vision_projector(config, mm_hidden_size, delay_load=False, **kwargs):
147
+ # NOTE: the rynnec projector only supports image-wise operation for now, i.e., temporal aggregation is prohibited
148
+ projector_type = getattr(config, 'mm_projector_type', 'linear')
149
+ hidden_size = config.hidden_size
150
+
151
+ if projector_type == "linear":
152
+ # NOTE: for both the linear and mlp2x_gelu projector types, mean pooling is adopted to aggregate video features
153
+ return nn.Linear(mm_hidden_size, hidden_size)
154
+ elif projector_type == "simp_spatial_conv":
155
+ return SimSpatialConv(mm_hidden_size, hidden_size)
156
+ elif projector_type.startswith("mlp"):
157
+ return MlpGeluProjector(mm_hidden_size, hidden_size, projector_type)
158
+ if projector_type == 'identity':
159
+ return IdentityMap()
160
+
161
+ raise ValueError(f'Unknown projector type: {projector_type}')
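As a quick sanity check on the projector variants above, the sketch below pushes dummy features through `build_vision_projector`. The hidden sizes are illustrative, and the import assumes the module path shown in this diff (`rynnec/model/projector.py`).

# Illustrative shape check for the projectors above (sizes are assumptions).
from types import SimpleNamespace

import torch

from rynnec.model.projector import build_vision_projector

tokens = torch.randn(2, 576, 1152)  # (batch, patches, mm_hidden_size), i.e. a 24x24 grid

mlp_cfg = SimpleNamespace(mm_projector_type="mlp2x_gelu", hidden_size=2048)
mlp_proj = build_vision_projector(mlp_cfg, mm_hidden_size=1152)
print(mlp_proj(tokens).shape)   # torch.Size([2, 576, 2048]): token count unchanged

conv_cfg = SimpleNamespace(mm_projector_type="simp_spatial_conv", hidden_size=2048)
conv_proj = build_vision_projector(conv_cfg, mm_hidden_size=1152)
print(conv_proj(tokens).shape)  # torch.Size([2, 169, 2048]): kernel/stride 2 with padding 1 maps 24x24 -> 13x13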
rynnec/model/region_encoder.py ADDED
@@ -0,0 +1,77 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from functools import partial
5
+
6
+ class MaskExtractor(nn.Module):
7
+ def __init__(self, config, mm_hidden_size, depth=2):
8
+ super(MaskExtractor, self).__init__()
9
+ self.mask_pooling = MaskPooling()
10
+ modules = [nn.Linear(mm_hidden_size, config.hidden_size)]
11
+ for _ in range(1, depth):
12
+ modules.append(nn.GELU())
13
+ modules.append(nn.Linear(config.hidden_size, config.hidden_size))
14
+ self.feat_linear = nn.Sequential(*modules)
15
+
16
+ def forward(self, feats, masks):
17
+ query_feats = []
18
+
19
+ if masks is None: #infer
20
+ return None
21
+
22
+ num_imgs = len(masks)
23
+ region_token_nums = []
24
+ image_idx = 0
25
+ for idx in range(num_imgs):
26
+ if masks[idx] is None:
27
+ continue
28
+ for mask_idx in range(len(masks[idx])):
29
+ mask = masks[idx][mask_idx].unsqueeze(0).unsqueeze(0).float()
30
+ if len(mask[0])==0:
31
+ mask = torch.zeros((1, 1, 336, 336)).to(feats.device).float()
32
+
33
+ feat = feats[image_idx].unsqueeze(0)
34
+ image_idx+=1
35
+
36
+ # h, w = feat.shape[1:3]
37
+ feat = feat.permute(0,3,1,2)
38
+
39
+ raw_dtype = feat.dtype
40
+ feat = feat.to(mask.dtype)
41
+
42
+ mask_feat_raw = self.mask_pooling(feat, mask) # [n, 1024]
43
+
44
+ query_feats.append(mask_feat_raw)
45
+ if len(query_feats)==0:
46
+ return None
47
+ mask_feats = torch.cat(query_feats, dim=0)
48
+ mask_feats = mask_feats.to(feats[0].dtype)
49
+ mask_feats_linear = self.feat_linear(mask_feats)
50
+ return mask_feats_linear
51
+
52
+
53
+ class MaskPooling(nn.Module):
54
+ def __init__(self):
55
+ super().__init__()
56
+
57
+ def forward(self, x, mask):
58
+
59
+ if not x.shape[-2:] == mask.shape[-2:]:
60
+ # reshape mask to x
61
+ mask = F.interpolate(mask, size=x.shape[-2:], mode='bilinear', align_corners=False)
62
+
63
+ # b, c, h ,w = x.shape
64
+ # b, q, h, w = mask.shape
65
+ mask = (mask > 0).to(mask.dtype)
66
+ mask = mask.permute(1,0,2,3)
67
+ denorm = mask.sum(dim=(-1, -2), keepdim=True) + 1e-8
68
+
69
+ mask_pooled_x = (x * mask/denorm).sum(-1).sum(-1)
70
+
71
+ return mask_pooled_x
72
+
73
+
74
+ def build_region_encoder(config, mm_hidden_size):
75
+
76
+ return MaskExtractor(config, mm_hidden_size)
77
+
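The pooling above is just a masked average over the feature map; a standalone sketch with illustrative shapes:

# Standalone sketch of MaskPooling: average features over the pixels selected
# by a binary region mask (all shapes are illustrative).
import torch

pool = MaskPooling()
feats = torch.randn(1, 1024, 24, 24)   # (B, C, H, W) visual feature map
mask = torch.zeros(1, 1, 336, 336)     # (B, Q, H, W) one region mask in image space
mask[:, :, 100:200, 120:220] = 1.0     # mark the region of interest

region_feat = pool(feats, mask)        # the mask is resized to 24x24 internally
print(region_feat.shape)               # torch.Size([1, 1024]): one pooled vector per mask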
rynnec/model/rynnec_arch.py ADDED
@@ -0,0 +1,271 @@
1
+ # Adapted from: https://github.com/DAMO-NLP-SG/VideoLLaMA3.
2
+ # Adapted from https://github.com/haotian-liu/LLaVA. Below is the original copyright:
3
+ # Copyright 2023 Haotian Liu
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ import os
18
+ import math
19
+ from abc import ABC, abstractmethod
20
+ from typing import List, Optional, Tuple, Union
21
+
22
+ import einops
23
+ import torch
24
+ import torch.distributed as dist
25
+ import torch.nn as nn
26
+ import numpy as np
27
+
28
+ from ..constants import IGNORE_INDEX, MODAL_INDEX_MAP, NUM_FRAMES
29
+ from .encoder import build_vision_encoder
30
+ from .projector import build_vision_projector, load_mm_projector
31
+ from .region_encoder import build_region_encoder
32
+ from ..mm_utils import reshape_images_to_raw_grid
33
+
34
+
35
+ def spatial_downsampling(features, grid_thws, stride=2):
36
+ n, c = features.shape
37
+
38
+ flatten_grid_thws = torch.cat([grid_thw for batch_grid_thws in grid_thws for grid_thw in batch_grid_thws])
39
+ split_sizes = [grid_thw.prod() for grid_thw in flatten_grid_thws]
40
+ features = torch.split(features, split_sizes)
41
+
42
+ new_features = []
43
+ for feature, grid_thw in zip(features, flatten_grid_thws):
44
+ # NOTE: adapted for reshape in image processor
45
+ feature = feature.view(grid_thw[0], grid_thw[1] // stride, grid_thw[2] // stride, stride, stride, c).permute(0, 1, 3, 2, 4, 5)
46
+ feature = feature.reshape(grid_thw[0], grid_thw[1], grid_thw[2], c).permute(0, 3, 1, 2)
47
+ # NOTE: the previous model version used align_corners=True
48
+ new_feature = torch.nn.functional.interpolate(feature, (math.ceil(grid_thw[1] / stride), math.ceil(grid_thw[2] / stride)), mode='bilinear')
49
+ # new_feature = nn.functional.avg_pool2d(feature, stride)
50
+ # new_feature = nn.functional.max_pool2d(feature, stride)
51
+ new_features.append(new_feature.permute(0, 2, 3, 1).view(-1, c))
52
+ new_features = torch.cat(new_features)
53
+
54
+ return new_features
55
+
56
+
57
+ class RynnecMetaModel:
58
+
59
+ def __init__(self, config):
60
+ super(RynnecMetaModel, self).__init__(config)
61
+
62
+ if hasattr(config, "vision_encoder") or hasattr(config, "mm_vision_encoder"):
63
+ self.vision_encoder = build_vision_encoder(config, delay_load=False)
64
+ self.mm_projector = build_vision_projector(config, config.mm_hidden_size)
65
+ self.region_encoder = build_region_encoder(config, config.mm_hidden_size)
66
+
67
+ def get_vision_encoder(self):
68
+ vision_encoder = getattr(self, 'vision_encoder', None)
69
+ if type(vision_encoder) is list:
70
+ vision_encoder = vision_encoder[0]
71
+ return vision_encoder
72
+
73
+ def get_mm_projector(self):
74
+ return self.mm_projector
75
+
76
+ def initialize_vision_modules(self, model_args, fsdp=None):
77
+ vision_encoder = model_args.vision_encoder
78
+ mm_vision_select_layer = model_args.mm_vision_select_layer
79
+ mm_vision_select_feature = model_args.mm_vision_select_feature
80
+ pretrain_mm_projector = model_args.pretrain_mm_projector
81
+
82
+ self.config.mm_vision_encoder = vision_encoder
83
+
84
+ if self.get_vision_encoder() is None:
85
+ vision_encoder = build_vision_encoder(model_args)
86
+
87
+ if fsdp is not None and len(fsdp) > 0:
88
+ self.vision_encoder = [vision_encoder]
89
+ else:
90
+ self.vision_encoder = vision_encoder
91
+ else:
92
+ if fsdp is not None and len(fsdp) > 0:
93
+ vision_encoder = self.vision_encoder[0]
94
+ else:
95
+ vision_encoder = self.vision_encoder
96
+ # NOTE: only compatible with delay_load encoder
97
+ # vision_encoder.load_model(vision_encoder.cfg_only)
98
+
99
+ self.config.use_mm_proj = True
100
+ self.config.mm_projector_type = getattr(model_args, 'mm_projector_type', 'linear')
101
+ self.config.mm_hidden_size = vision_encoder.hidden_size
102
+ self.config.mm_vision_select_layer = mm_vision_select_layer
103
+ self.config.mm_vision_select_feature = mm_vision_select_feature
104
+
105
+ if getattr(self, 'mm_projector', None) is None:
106
+ self.mm_projector = build_vision_projector(self.config)
107
+ else:
108
+ # In case it is frozen by LoRA
109
+ for p in self.mm_projector.parameters():
110
+ p.requires_grad = True
111
+
112
+ if pretrain_mm_projector is not None:
113
+ if os.path.exists(pretrain_mm_projector):
114
+ is_local = True
115
+ if os.path.isdir(pretrain_mm_projector):
116
+ mm_projector_weights = load_mm_projector(pretrain_mm_projector)
117
+ else:
118
+ mm_projector_weights = torch.load(pretrain_mm_projector, map_location='cpu')
119
+ else:
120
+ # Support loading projector weights from remote HuggingFace model hub
121
+ is_local = False
122
+ pretrain_mm_projector = pretrain_mm_projector.replace('mm_projector.bin', '')
123
+ pretrain_mm_projector = pretrain_mm_projector.strip('/').strip('\\').strip()
124
+ mm_projector_weights = load_mm_projector(pretrain_mm_projector)
125
+
126
+ def get_w(weights, keyword):
127
+ return {k.split(keyword + '.')[1]: v for k, v in weights.items() if keyword in k}
128
+
129
+ self.mm_projector.load_state_dict(get_w(mm_projector_weights, 'mm_projector'), strict=False)
130
+
131
+
132
+ class RynnecMetaForCausalLM(ABC):
133
+
134
+ @abstractmethod
135
+ def get_model(self):
136
+ pass
137
+
138
+ def num_frames(self):
139
+ if hasattr(self.config, 'num_frames'):
140
+ return self.config.num_frames
141
+ else:
142
+ return NUM_FRAMES
143
+
144
+ def spatial_merge_size(self):
145
+ if hasattr(self.config, 'spatial_merge_size'):
146
+ return self.config.spatial_merge_size
147
+ else:
148
+ return 1
149
+
150
+ def get_vision_encoder(self):
151
+ return self.get_model().get_vision_encoder()
152
+
153
+ def get_mm_projector(self):
154
+ return self.get_model().get_mm_projector()
155
+
156
+ def encode_images(
157
+ self,
158
+ pixel_values: torch.FloatTensor,
159
+ grid_sizes: torch.LongTensor,
160
+ merge_sizes: torch.LongTensor,
161
+ ):
162
+ mm_features, mm_features_raw = self.get_model().get_vision_encoder()(
163
+ pixel_values=pixel_values,
164
+ grid_sizes=grid_sizes,
165
+ merge_sizes=merge_sizes,
166
+ )
167
+ mm_features = self.get_model().mm_projector(mm_features)
168
+ return mm_features, mm_features_raw
169
+
170
+ def _get_valid_visual_tokens(
171
+ self,
172
+ mm_features: torch.FloatTensor,
173
+ batched_num_patches: torch.LongTensor,
174
+ modals: List[str],
175
+ ):
176
+ valid_masks = []
177
+ for num_patches, modal in zip(batched_num_patches, modals):
178
+ valid_mask = torch.full((num_patches, ), modal != "text", dtype=torch.bool, device=mm_features.device)
179
+ valid_masks.append(valid_mask)
180
+ mm_features = mm_features[torch.cat(valid_masks)]
181
+ return mm_features
182
+
183
+ def prepare_inputs_labels_for_multimodal(
184
+ self,
185
+ input_ids: torch.LongTensor = None,
186
+ attention_mask: Optional[torch.Tensor] = None,
187
+ position_ids: Optional[torch.LongTensor] = None,
188
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
189
+ labels: Optional[torch.LongTensor] = None,
190
+ pixel_values: Optional[torch.FloatTensor] = None,
191
+ grid_sizes: Optional[torch.LongTensor] = None,
192
+ merge_sizes: Optional[torch.LongTensor] = None,
193
+ modals: Optional[List[str]] = None,
194
+ masks=None,
195
+ mask_ids = None
196
+ ):
197
+ vision_encoder = self.get_vision_encoder()
198
+ # NOTE: text-only situation
199
+ if vision_encoder is None or pixel_values is None or input_ids.shape[1] == 1:
200
+ return input_ids, attention_mask, position_ids, past_key_values, None, labels
201
+
202
+ # 1. flatten text inputs
203
+ B, N = input_ids.shape
204
+ input_ids = input_ids.view(B * N)
205
+ if attention_mask is not None:
206
+ attention_mask = attention_mask.view(B * N)
207
+ if position_ids is not None:
208
+ position_ids = position_ids.view(B * N)
209
+ if labels is not None:
210
+ labels = labels.view(B * N)
211
+
212
+ # 2. embed visual tokens
213
+ batched_num_patches = grid_sizes.prod(dim=1).div(merge_sizes ** 2).long()
214
+
215
+ mm_features, mm_features_raw = self.encode_images(pixel_values, grid_sizes, merge_sizes)
216
+ mm_features = mm_features.to(input_ids.device)
217
+ mm_features_raw = mm_features_raw.to(input_ids.device)
218
+ mm_features = self._get_valid_visual_tokens(mm_features, batched_num_patches, modals)
219
+
220
+ # 3. embed text tokens
221
+ image_selected = (input_ids == self.config.image_token_index)
222
+ # input_ids[image_selected] = 0
223
+ inputs_embeds = self.get_model().embed_tokens(input_ids).clone()
224
+
225
+ num_vision_tokens = image_selected.sum()
226
+ if mm_features.size(0) > num_vision_tokens:
227
+ print(f"Number of mm_features ({mm_features.size(0)}) exceeds the number of image tokens ({num_vision_tokens}). Automative truncated.")
228
+ mm_features = mm_features[:num_vision_tokens]
229
+
230
+ # 4. replace multimodal tokens with features
231
+ inputs_embeds[image_selected] = inputs_embeds[image_selected] * 0.0 + mm_features
232
+
233
+ # 5. embed region tokens
234
+ try:
235
+
236
+ mask_selected = (input_ids == self.config.region_token_index)
237
+
238
+ if mask_selected.sum() > 0:
239
+ reshaped_features = reshape_images_to_raw_grid(mm_features_raw, grid_sizes)
240
+ mask_additional_image_features = []
241
+ idx = 0
242
+ new_masks = []
243
+ for bs in range(len(masks)):
244
+ flag=True
245
+ for ml in range(len(masks[bs])):
246
+ if mask_ids[idx]>=0:
247
+ mask_additional_image_features.append(reshaped_features[mask_ids[idx]])
248
+ else:
249
+ flag=False
250
+ idx+=1
251
+ if flag:
252
+ new_masks.append(masks[bs])
253
+
254
+ mask_feats = self.get_model().region_encoder(mask_additional_image_features, new_masks)
255
+ inputs_embeds[mask_selected] = inputs_embeds[mask_selected]*0.0 + mask_feats
256
+
257
+ except Exception as e:
258
+ print(e)
259
+
260
+
261
+ # 6. reshape back to batched format
262
+ C = inputs_embeds.shape[-1]
263
+ inputs_embeds = inputs_embeds.reshape(B, -1, C)
264
+ if attention_mask is not None:
265
+ attention_mask = attention_mask.view(B, -1)
266
+ if labels is not None:
267
+ labels = labels.view(B, -1)
268
+ if position_ids is not None:
269
+ position_ids = position_ids.view(B, -1)
270
+
271
+ return None, attention_mask, position_ids, past_key_values, inputs_embeds, labels
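A toy illustration of step 4 above (scattering visual features into the image placeholder slots) is given below; the token id and tensor sizes are made up. The `* 0.0 +` pattern keeps the original word embeddings in the autograd graph while the visual features take over the values, which is commonly done to avoid unused-parameter issues in distributed training; that rationale is inferred, not stated in the code.

# Toy version of the placeholder-replacement step (made-up id and sizes).
import torch

IMAGE_TOKEN_INDEX = 32000                  # placeholder token id (assumption)
input_ids = torch.tensor([101, 7, IMAGE_TOKEN_INDEX, IMAGE_TOKEN_INDEX, 9])
inputs_embeds = torch.randn(5, 8)          # flattened text embeddings, hidden size 8
mm_features = torch.randn(2, 8)            # one projected feature per image token

selected = input_ids == IMAGE_TOKEN_INDEX
inputs_embeds = inputs_embeds.clone()
inputs_embeds[selected] = inputs_embeds[selected] * 0.0 + mm_features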
rynnec/model/rynnec_qwen2.py ADDED
@@ -0,0 +1,638 @@
1
+ # Adapted from: https://github.com/DAMO-NLP-SG/VideoLLaMA3.
2
+ # Adapted from: https://github.com/haotian-liu/LLaVA.
3
+ # Below is the original copyright:
4
+ # Copyright 2023 Haotian Liu
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+
19
+ from typing import List, Optional, Tuple, Union, Dict
20
+ import numpy as np
21
+ import torch
22
+ import torch.nn as nn
23
+ import torch.nn.functional as F
24
+ from transformers import (AutoConfig, AutoModelForCausalLM, AutoProcessor, AutoImageProcessor,
25
+ Qwen2Config, Qwen2ForCausalLM, Qwen2Model)
26
+ from transformers.generation.utils import GenerateOutput
27
+ # from transformers.modeling_outputs import CausalLMOutputWithPast
28
+ from dataclasses import dataclass
29
+ from transformers.utils import ModelOutput
30
+
31
+ from .loss import cross_entropy_loss, CrossEntropyLoss, DiceLoss
32
+ from .processor import Videollama3BaseProcessor
33
+ from .rynnec_arch import RynnecMetaForCausalLM, RynnecMetaModel
34
+ from .videollama3_encoder import Videollama3ImageProcessor
35
+ from rynnec.constants import IGNORE_INDEX, DEFAULT_IMAGE_TOKEN
36
+ from .sam2_train import SAM2TrainRunner
37
+ from .sam2 import SAM2
38
+ from .utils import genetate_video_pred_embeddings, process_video_gt_masks
39
+
40
+ CHAT_TEMPLATE = """
41
+ {%- set identifier = 'im' %}
42
+ {% for message in messages %}
43
+ {% if message['role'] == 'stream' %}
44
+ {% set identifier = 'stream' %}
45
+ {% else %}
46
+ {% set identifier = 'im' %}
47
+ {% endif %}
48
+ {% if message['role'] is not none %}
49
+ {{- '<|' + identifier + '_start|>' + message['role'] + '\n' -}}
50
+ {% endif %}
51
+ {% if message['content'] is string %}
52
+ {{- message['content'] + '<|' + identifier + '_end|>\n' -}}
53
+ {% else %}
54
+ {% for content in message['content'] %}
55
+ {% if content['type'] == 'image' or 'image' in content or 'image_url' in content %}
56
+ {% if 'time' in content %}
57
+ {{- 'Time ' + content['time'] | round(1) | string + 's: ' -}}
58
+ {% endif %}
59
+ {{- image_token + '\n' -}}
60
+ {% elif content['type'] == 'video' or 'video' in content or 'video_url' in content %}
61
+ {% for i in range(content['num_frames']) %}
62
+ {% if 'timestamps' in content %}
63
+ {{- 'Time ' + content['timestamps'][i] | round(1) | string + 's:' -}}
64
+ {% endif %}
65
+ {% if i < content['num_frames'] - 1 %}
66
+ {{- image_token + ',' -}}
67
+ {% else %}
68
+ {{- image_token + '\n' -}}
69
+ {% endif %}
70
+ {% endfor %}
71
+ {% elif content['type'] == 'text' or 'text' in content %}
72
+ {{- content['text'] -}}
73
+ {% endif %}
74
+ {% endfor %}
75
+ {% if message['role'] is not none %}
76
+ {{- '<|' + identifier + '_end|>\n' -}}
77
+ {% endif %}
78
+ {% endif %}
79
+ {% endfor %}
80
+ {% if add_generation_prompt %}
81
+ {{- '<|im_start|>assistant\n' -}}
82
+ {% if add_think_prompt %}
83
+ {{- '<think>\n' -}}
84
+ {% endif %}
85
+ {% endif %}
86
+ """
87
+ @dataclass
88
+ class CausalLMOutputWithPast(ModelOutput):
89
+ loss: Optional[torch.FloatTensor] = None
90
+ logits: torch.FloatTensor = None
91
+ past_key_values: Optional[List[torch.FloatTensor]] = None
92
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
93
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
94
+ rope_deltas: Optional[torch.LongTensor] = None
95
+ ce_loss: Optional[torch.FloatTensor] = None
96
+ mask_bce_loss: Optional[torch.FloatTensor] = None
97
+ mask_dice_loss: Optional[torch.FloatTensor] = None
98
+ mask_loss: Optional[torch.FloatTensor] = None
99
+
100
+
101
+ class Videollama3Qwen2Processor(Videollama3BaseProcessor):
102
+
103
+ tokenizer_class = ("Qwen2Tokenizer", "Qwen2TokenizerFast")
104
+ chat_template = CHAT_TEMPLATE
105
+
106
+ def __init__(
107
+ self,
108
+ image_processor=None,
109
+ tokenizer=None,
110
+ chat_template=None,
111
+ image_merge_size: int = 1,
112
+ video_merge_size: int = 2,
113
+ fps=1,
114
+ max_frames=180,
115
+ **kwargs
116
+ ):
117
+ super().__init__(image_processor, tokenizer, chat_template, **kwargs)
118
+ self.generation_prompt = self._infer_generation_prompt()
119
+ self.generation_prompt_ids = self.tokenizer.encode(self.generation_prompt, return_tensors="pt")
120
+ self.generation_prompt_length = len(self.generation_prompt_ids[0])
121
+
122
+ def _infer_generation_prompt(self):
123
+ pseudo_message = [{"role": "user", "content": ""}]
124
+ instruction = self.apply_chat_template(pseudo_message, tokenize=False, add_generation_prompt=True)
125
+ conversation = self.apply_chat_template(pseudo_message, tokenize=False, add_generation_prompt=False)
126
+ return instruction.replace(conversation, "")
127
+
128
+ def _process_text_with_label(
129
+ self,
130
+ text: List[Dict],
131
+ grid_sizes: torch.Tensor = None,
132
+ **kwargs,
133
+ ):
134
+ assert kwargs.pop("return_tensors", "pt") == "pt", "Only PyTorch tensors are supported when return_labels=True."
135
+ assert isinstance(text[0], dict), "When return_labels=True, text must be a list of messages."
136
+
137
+ input_ids_list = []
138
+ targets_list = []
139
+ image_idx = 0
140
+
141
+ for message_idx, message in enumerate(text):
142
+ # 1. set chat template and append image tokens
143
+ prompt = self.apply_chat_template([message], tokenize=False, add_generation_prompt=False)
144
+ prompt_chunks = prompt.split(DEFAULT_IMAGE_TOKEN)
145
+ prompt = []
146
+ for chunk_idx in range(len(prompt_chunks) - 1):
147
+ prompt.append(prompt_chunks[chunk_idx])
148
+ thw = grid_sizes[image_idx]
149
+ prompt.append(DEFAULT_IMAGE_TOKEN * thw.prod().long())
150
+ image_idx += 1
151
+ prompt.append(prompt_chunks[-1])
152
+ prompt = "".join(prompt)
153
+
154
+ input_ids = self.tokenizer.encode(prompt, return_tensors="pt")[0]
155
+ input_ids_list.append(input_ids)
156
+
157
+ targets = torch.full_like(input_ids, IGNORE_INDEX)
158
+ if message["role"] == "assistant" or message["role"] is None:
159
+ targets[self.generation_prompt_length:-1] = input_ids[self.generation_prompt_length:-1].clone()
160
+
161
+ # NOTE: mask out image tokens
162
+ vision_mask = input_ids == self.image_token_id
163
+ targets[vision_mask] = IGNORE_INDEX
164
+ vision_indices = torch.nonzero(vision_mask, as_tuple=True)[0]
165
+ targets[vision_indices + 1] = IGNORE_INDEX
166
+
167
+ # NOTE: mask out <think> or <think>\n
168
+ think_mask = targets == self.think_start_token_id
169
+ targets[think_mask] = IGNORE_INDEX
170
+ think_indices = torch.nonzero(think_mask, as_tuple=True)[0]
171
+ newline_mask = torch.zeros_like(think_mask)
172
+ newline_mask[think_indices + 1] = targets[think_indices + 1] == self.newline_token_id
173
+ targets[newline_mask] = IGNORE_INDEX
174
+
175
+ targets_list.append(targets)
176
+
177
+ assert len(grid_sizes) == image_idx, "Number of images does not match the number of image tokens in the text."
178
+
179
+ text_inputs = {
180
+ "input_ids": torch.cat(input_ids_list),
181
+ "labels": torch.cat(targets_list),
182
+ }
183
+
184
+ return text_inputs
185
+
186
+
187
+ class RynnecQwen2Config(Qwen2Config):
188
+ model_type = "rynnec_qwen2"
189
+
190
+ def __init__(self, **kwargs):
191
+ super().__init__(**kwargs)
192
+ self.model_type = "rynnec_qwen2"
193
+
194
+
195
+ class RynnecQwen2Model(RynnecMetaModel, Qwen2Model):
196
+ config_class = RynnecQwen2Config
197
+
198
+ def __init__(self, config: RynnecQwen2Config):
199
+ super(RynnecQwen2Model, self).__init__(config)
200
+
201
+ if hasattr(config, "mm_mask_decoder"): # inference
202
+ self.build_mask_decoder(config)
203
+ else: # training
204
+ if 'out_dim' not in config:
205
+ config.out_dim = 256
206
+
207
+ def build_mask_decoder(self, config):
208
+
209
+ # Projection layer for lisa
210
+ in_dim = config.hidden_size
211
+ out_dim = config.out_dim
212
+ text_fc = [
213
+ nn.Linear(in_dim, in_dim),
214
+ nn.ReLU(inplace=True),
215
+ nn.Linear(in_dim, out_dim),
216
+ nn.Dropout(0.0),
217
+ ]
218
+ self.text_hidden_fcs = nn.ModuleList([nn.Sequential(*text_fc)])
219
+ self.text_hidden_fcs.train()
220
+ for param in self.text_hidden_fcs.parameters():
221
+ param.requires_grad = True
222
+
223
+
224
+ class RynnecQwen2ForCausalLM(Qwen2ForCausalLM, RynnecMetaForCausalLM):
225
+ config_class = RynnecQwen2Config
226
+
227
+ def __init__(self, config, **kwargs):
228
+ super(Qwen2ForCausalLM, self).__init__(config)
229
+ self.model = RynnecQwen2Model(config)
230
+ self.vocab_size = config.vocab_size
231
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
232
+
233
+ # Initialize weights and apply final processing
234
+ self.post_init()
235
+
236
+ if hasattr(config, "training") and config.training is True:
237
+ self.grounding_encoder = SAM2TrainRunner(ckpt_path=config.mask_decoder_model)
238
+ config.mm_mask_decoder = True
239
+ else:
240
+ self.grounding_encoder = SAM2(ckpt_path=config.mask_decoder_model)
241
+
242
+ self.loss_mask = CrossEntropyLoss(
243
+ use_sigmoid=True,
244
+ reduction='mean',
245
+ loss_weight=2.0
246
+ )
247
+ self.loss_dice = DiceLoss(
248
+ use_sigmoid=True,
249
+ activate=True,
250
+ reduction='mean',
251
+ naive_dice=True,
252
+ eps=1.0,
253
+ loss_weight=0.5
254
+ )
255
+
256
+ def load_sam2_weights(self, model_path):
257
+ sam2_model = torch.load(model_path, map_location='cpu')['model']
258
+ prefix = "sam2_model."
259
+ new_state_dict = {}
260
+ for param_name in sam2_model.keys():
261
+ new_param_name = prefix + param_name
262
+ new_state_dict[new_param_name] = sam2_model[param_name]
263
+
264
+ self.grounding_encoder.load_state_dict(new_state_dict, strict=False)
265
+
266
+ def get_model(self):
267
+ return self.model
268
+ # NOTE: arguments are copied from transformers==4.46.3
269
+ def forward(
270
+ self,
271
+ input_ids: torch.LongTensor = None,
272
+ attention_mask: Optional[torch.Tensor] = None,
273
+ position_ids: Optional[torch.LongTensor] = None,
274
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
275
+ inputs_embeds: Optional[torch.FloatTensor] = None,
276
+ labels: Optional[torch.LongTensor] = None,
277
+ use_cache: Optional[bool] = None,
278
+ output_attentions: Optional[bool] = None,
279
+ output_hidden_states: Optional[bool] = None,
280
+ return_dict: Optional[bool] = None,
281
+ cache_position: Optional[torch.LongTensor] = None,
282
+ num_logits_to_keep: int = 0,
283
+ # multimodal inputs
284
+ pixel_values: Optional[torch.FloatTensor] = None,
285
+ grid_sizes: Optional[torch.LongTensor] = None,
286
+ merge_sizes: Optional[torch.LongTensor] = None,
287
+ modals: Optional[List[str]] = None,
288
+ masks: Optional[List[torch.LongTensor]] = None,
289
+ mask_ids = None,
290
+ sam_images = None,
291
+ sam_size = None,
292
+ image2maskids = None,
293
+ **loss_kwargs,
294
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
295
+ torch.cuda.empty_cache()
296
+ if inputs_embeds is None:
297
+ input_ids_raw = input_ids.clone()
298
+ (
299
+ input_ids,
300
+ attention_mask,
301
+ position_ids,
302
+ past_key_values,
303
+ inputs_embeds,
304
+ labels,
305
+ ) = self.prepare_inputs_labels_for_multimodal(
306
+ input_ids=input_ids,
307
+ attention_mask=attention_mask,
308
+ position_ids=position_ids,
309
+ past_key_values=past_key_values,
310
+ labels=labels,
311
+ pixel_values=pixel_values,
312
+ grid_sizes=grid_sizes,
313
+ merge_sizes=merge_sizes,
314
+ modals=modals,
315
+ masks=masks,
316
+ mask_ids=mask_ids
317
+ )
318
+
319
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
320
+ output_hidden_states = (
321
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
322
+ )
323
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
324
+
325
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
326
+ outputs = self.model(
327
+ input_ids=input_ids,
328
+ attention_mask=attention_mask,
329
+ position_ids=position_ids,
330
+ past_key_values=past_key_values,
331
+ inputs_embeds=inputs_embeds,
332
+ use_cache=use_cache,
333
+ output_attentions=output_attentions,
334
+ output_hidden_states=output_hidden_states,
335
+ return_dict=return_dict,
336
+ cache_position=cache_position,
337
+ )
338
+
339
+ hidden_states = outputs[0]
340
+ loss, logits = None, None
341
+ _valid = True
342
+ seg_valid = True
343
+
344
+ if labels is not None: #training
345
+
346
+ ce_loss = cross_entropy_loss(
347
+ hidden_states=hidden_states,
348
+ lm_head=self.lm_head,
349
+ position_ids=position_ids,
350
+ labels=labels,
351
+ reduction_scope=self.config.loss_reduction_scope,
352
+ **loss_kwargs,
353
+ )
354
+
355
+ if self.config.has_mask:
356
+
357
+ hidden_states_sam = []
358
+ hidden_states_sam.append(self.model.text_hidden_fcs[0](hidden_states))
359
+ hidden_states_sam = torch.stack(hidden_states_sam, dim=-1).sum(dim=-1)
360
+
361
+ bs = input_ids_raw.shape[0]
362
+ gt_masks_list = []
363
+ pred_masks_list = []
364
+ mask_bce_loss = 0
365
+ mask_dice_loss = 0
366
+ num_masks = 0
367
+ for i in range(bs):
368
+ pred_masks = []
369
+ pred_embeddings = []
370
+ input_id = input_ids_raw[i]
371
+ seg_token_mask = input_id[1:]==self.config.seg_token_index
372
+ seg_token_mask = torch.cat(
373
+ [
374
+ seg_token_mask,
375
+ torch.zeros((1)).bool().cuda(),
376
+ ],
377
+ dim=0,
378
+ )
379
+
380
+ pred_embedding = hidden_states_sam[i][seg_token_mask]
381
+ if len(pred_embedding)>0:
382
+ pred_embeddings.append(pred_embedding)
383
+ else:
384
+ pred_embeddings.append(hidden_states_sam[i, :1])
385
+
386
+
387
+ gt_masks_video = [] # FIXME: only one segmentation is supported for now
388
+ gt_mask = masks[i]
389
+ mask_valid = True
390
+
391
+ if len(image2maskids[i])==0:
392
+ sam_images[i] = sam_images[i][:1]
393
+ gt_masks_video.append(torch.zeros((len(sam_images[i]), 224, 224)).to(sam_images[0].device))
394
+ mask_valid = False
395
+
396
+ else:
397
+ for mids in image2maskids[i]:
398
+ for mid in mids:
399
+ if mid is None:
400
+ gt_masks_video.append(torch.zeros((224, 224)).unsqueeze(0).to(gt_mask[0].device))
401
+ else:
402
+ gt_masks_video.append(gt_mask[mid].unsqueeze(0))
403
+ frames_per_batch = [len(sam_images[i])]
404
+ try:
405
+ pred_embeddings_list_video = genetate_video_pred_embeddings(pred_embeddings, frames_per_batch)
406
+
407
+ # pred_embeddings_list_video, gt_masks_video = check_obj_number(pred_embeddings_list_video, gt_masks_video)
408
+
409
+ g_pixel_values = sam_images[i]
410
+ num_objs = len(pred_embeddings_list_video[0])
411
+
412
+ # with torch.no_grad():
413
+
414
+ sam_states = self.grounding_encoder.get_sam2_embeddings(g_pixel_values, expand_size=num_objs)
415
+ language_embeddings = torch.cat(pred_embeddings_list_video, dim=0)[:, None]#.contiguous()
416
+
417
+ num_frames = len(pred_embeddings_list_video)
418
+ gt_masks_video = process_video_gt_masks(gt_masks_video, num_frames, num_objs)
419
+ pred_masks = self.grounding_encoder.inject_language_embd(sam_states, language_embeddings, nf_nobj=(num_frames, num_objs))
420
+
421
+ gt_masks = [F.interpolate(gt_mask.unsqueeze(0), size=pred_masks[0].shape[-2:], mode='nearest').squeeze(0) for gt_mask in gt_masks_video]
422
+ gt_masks = torch.cat(gt_masks, dim=0)
423
+ pred_masks = pred_masks.flatten(0, 1)
424
+
425
+ if not mask_valid:
426
+ pred_masks = pred_masks*0.0
427
+
428
+ if len(pred_masks) != len(gt_masks):
429
+ # drop this data
430
+ print(f"Pred mask shape {pred_masks.shape} is not equal to gt_mask shape {gt_masks.shape} !!!")
431
+ min_num = min(len(pred_masks), len(gt_masks))
432
+ pred_masks = pred_masks[:min_num]
433
+ gt_masks = gt_masks[:min_num]
434
+ seg_valid = False
435
+
436
+ if not seg_valid or not mask_valid:
437
+ _scale = 0.0
438
+ else:
439
+ _scale = 1.0
440
+
441
+ mask_bce_loss_ = self.loss_mask(pred_masks, gt_masks) * len(pred_masks) * _scale
442
+ mask_dice_loss_ = self.loss_dice(pred_masks, gt_masks) * len(gt_masks) * _scale
443
+ mask_bce_loss += mask_bce_loss_
444
+ mask_dice_loss += mask_dice_loss_
445
+ num_masks += len(pred_masks)
446
+ except Exception as exp:
447
+ print(exp)
448
+ _valid = False
449
+
450
+
451
+ if num_masks>0:
452
+ mask_bce_loss = mask_bce_loss / num_masks
453
+ mask_dice_loss = mask_dice_loss / num_masks
454
+
455
+ mask_bce_loss = self.config.bce_loss_weight * mask_bce_loss
456
+ mask_dice_loss = self.config.dice_loss_weight * mask_dice_loss
457
+ if not _valid:
458
+ mask_bce_loss = mask_bce_loss * 0.0
459
+ mask_dice_loss = mask_dice_loss* 0.0
460
+
461
+ mask_loss = mask_bce_loss + mask_dice_loss
462
+ loss = mask_loss + ce_loss
463
+ else:
464
+ loss = ce_loss
465
+
466
+ else:
467
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
468
+ logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
469
+
470
+ if not return_dict:
471
+ output = (logits,) + outputs[1:]
472
+ return (loss,) + output if loss is not None else output
473
+
474
+ if loss is not None:
475
+ if self.config.has_mask:
476
+ return CausalLMOutputWithPast(
477
+ loss=loss,
478
+ ce_loss=ce_loss.detach(),
479
+ mask_bce_loss=mask_bce_loss.detach(),
480
+ mask_dice_loss=mask_dice_loss.detach(),
481
+ mask_loss=mask_loss.detach(),
482
+ logits=logits,
483
+ past_key_values=outputs.past_key_values,
484
+ hidden_states=outputs.hidden_states,
485
+ attentions=outputs.attentions,
486
+ )
487
+ else:
488
+ return CausalLMOutputWithPast(
489
+ loss=loss,
490
+ logits=logits,
491
+ past_key_values=outputs.past_key_values,
492
+ hidden_states=outputs.hidden_states,
493
+ attentions=outputs.attentions,
494
+ )
495
+ else: #infer
496
+ return CausalLMOutputWithPast(
497
+ loss=loss,
498
+ logits=logits,
499
+ past_key_values=outputs.past_key_values,
500
+ hidden_states=outputs.hidden_states,
501
+ attentions=outputs.attentions,
502
+ )
503
+
504
+ @torch.no_grad()
505
+ def inference(
506
+ self,
507
+ # multimodal inputs
508
+ pixel_values: Optional[torch.FloatTensor] = None,
509
+ grid_sizes: Optional[torch.LongTensor] = None,
510
+ merge_sizes: Optional[torch.LongTensor] = None,
511
+ modals: Optional[List[str]] = None,
512
+ masks: Optional[List[torch.LongTensor]] = None,
513
+ mask_ids = None,
514
+ sam_images = None,
515
+ sam_size = None,
516
+ image2maskids = None,
517
+ seg_start_idx = 0,
518
+ **kwargs,
519
+ ):
520
+ outputs = self.generate(
521
+ pixel_values=pixel_values,
522
+ grid_sizes=grid_sizes,
523
+ merge_sizes=merge_sizes,
524
+ modals=modals,
525
+ masks=masks,
526
+ mask_ids=mask_ids,
527
+ output_hidden_states=True,
528
+ return_dict_in_generate=True,
529
+ **kwargs
530
+ )
531
+
532
+ input_ids = kwargs.pop('input_ids')
533
+ last_hidden_state = []
534
+ for hs in outputs.hidden_states: # round
535
+ last_hidden_state.append(hs[-1])
536
+ last_hidden_state = torch.cat(last_hidden_state, dim=1)
537
+
538
+ output_ids = outputs.sequences
539
+
540
+ concat_ids = torch.cat((input_ids, output_ids), dim=1)
541
+ seg_token_mask = concat_ids[:, 1:] == self.config.seg_token_index
542
+
543
+ last_hidden_state_sam = self.model.text_hidden_fcs[0](last_hidden_state)
544
+
545
+ pred_embeddings = last_hidden_state_sam[seg_token_mask]
546
+ seg_token_counts = seg_token_mask.int().sum()
547
+
548
+ if seg_token_counts>0:
549
+
550
+ g_pixel_values = torch.cat(sam_images, dim=0).contiguous()
551
+ num_objs = 1 # FIXME: only one segmentation is supported for now
552
+ if seg_start_idx>0:
553
+ # before start idx
554
+ g_pixel_values_beg = g_pixel_values[:seg_start_idx+1].flip(0)
555
+ num_frames = len(g_pixel_values_beg)
556
+ sam_states_beg = self.grounding_encoder.get_sam2_embeddings(g_pixel_values_beg)
557
+ pred_masks_beg = self.grounding_encoder.language_embd_inference(sam_states_beg, [pred_embeddings]*num_frames)
558
+ else:
559
+ pred_masks_beg = torch.zeros((1, 1, 1024, 1024)).to(pixel_values.device)
560
+
561
+ if seg_start_idx<=len(g_pixel_values)-1:
562
+ g_pixel_values_end = g_pixel_values[seg_start_idx:]
563
+ num_frames = len(g_pixel_values_end)
564
+ sam_states_end = self.grounding_encoder.get_sam2_embeddings(g_pixel_values_end)
565
+ pred_masks_end = self.grounding_encoder.language_embd_inference(sam_states_end, [pred_embeddings]*num_frames)
566
+ else:
567
+ pred_masks_end = torch.zeros((0, 1, 1024, 1024)).to(pixel_values.device)
568
+
569
+ pred_masks = torch.cat([pred_masks_beg[1:].flip(0), pred_masks_end], dim=0)
570
+
571
+ return output_ids, pred_masks
572
+
573
+
574
+ @torch.no_grad()
575
+ def generate(
576
+ self,
577
+ # multimodal inputs
578
+ pixel_values: Optional[torch.FloatTensor] = None,
579
+ grid_sizes: Optional[torch.LongTensor] = None,
580
+ merge_sizes: Optional[torch.LongTensor] = None,
581
+ modals: Optional[List[str]] = None,
582
+ masks: Optional[List[torch.LongTensor]] = None,
583
+ mask_ids = None,
584
+ **kwargs,
585
+ ) -> Union[GenerateOutput, torch.LongTensor]:
586
+ input_ids = kwargs.pop("input_ids", None)
587
+ attention_mask = kwargs.pop("attention_mask", None)
588
+ position_ids = kwargs.pop("position_ids", None)
589
+ past_key_values = kwargs.pop("past_key_values", None)
590
+
591
+ if "inputs_embeds" in kwargs:
592
+ raise NotImplementedError("`inputs_embeds` is not supported")
593
+
594
+ if pixel_values is not None:
595
+ (
596
+ input_ids,
597
+ attention_mask,
598
+ position_ids,
599
+ past_key_values,
600
+ inputs_embeds,
601
+ labels,
602
+ ) = self.prepare_inputs_labels_for_multimodal(
603
+ input_ids=input_ids,
604
+ attention_mask=attention_mask,
605
+ position_ids=position_ids,
606
+ past_key_values=past_key_values,
607
+ labels=None,
608
+ pixel_values=pixel_values,
609
+ grid_sizes=grid_sizes,
610
+ merge_sizes=merge_sizes,
611
+ modals=modals,
612
+ masks=masks,
613
+ mask_ids=mask_ids
614
+ )
615
+ else:
616
+ inputs_embeds = self.get_model().embed_tokens(input_ids)
617
+
618
+ return super().generate(
619
+ position_ids=position_ids,
620
+ attention_mask=attention_mask,
621
+ inputs_embeds=inputs_embeds,
622
+ **kwargs
623
+ )
624
+
625
+ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
626
+ images = kwargs.pop("images", None)
627
+ _inputs = super().prepare_inputs_for_generation(
628
+ input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
629
+ )
630
+ if images is not None:
631
+ _inputs['images'] = images
632
+ return _inputs
633
+
634
+
635
+ AutoConfig.register("rynnec_qwen2", RynnecQwen2Config)
636
+ AutoModelForCausalLM.register(RynnecQwen2Config, RynnecQwen2ForCausalLM)
637
+ AutoProcessor.register(RynnecQwen2Config, Videollama3Qwen2Processor)
638
+ AutoImageProcessor.register(RynnecQwen2Config, Videollama3ImageProcessor)
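Because the classes are registered with the `Auto*` factories above, a checkpoint shipping this configuration can in principle be loaded with the standard `from_pretrained` API once the `rynnec` package has been imported. The path below is a placeholder, and a real checkpoint must also provide the extra config fields referenced in this file (for example `mask_decoder_model` and `seg_token_index`).

# Loading sketch based on the Auto* registrations above (path is a placeholder).
import torch

import rynnec.model.rynnec_qwen2  # noqa: F401  (runs the register() calls above)
from transformers import AutoModelForCausalLM, AutoProcessor

model_path = "path/to/rynnec-checkpoint"
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
processor = AutoProcessor.from_pretrained(model_path)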
rynnec/model/sam2.py ADDED
@@ -0,0 +1,133 @@
1
+ # Adapted from https://github.com/magic-research/Sa2VA/blob/main/projects/llava_sam2/models/sam2.py.
2
+ # Below is the original copyright:
3
+ # coding=utf-8
4
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ import os.path
19
+
20
+ import torch
21
+ import torch.nn as nn
22
+ from hydra import compose
23
+ from hydra.utils import instantiate
24
+ from omegaconf import OmegaConf
25
+
26
+ from .utils import load_checkpoint_with_prefix, load_state_dict_to_model
27
+
28
+ class SAM2(nn.Module):
29
+ def __init__(
30
+ self,
31
+ cfg_path: str = "sam2_hiera_l.yaml",
32
+ ckpt_path: str = "sam2_hiera_large.pt",
33
+ hydra_overrides_extra=None,
34
+ apply_postprocessing=True,
35
+ ):
36
+ super().__init__()
37
+
38
+ import third_parts.sam2 # noqa: F401
39
+
40
+ if hydra_overrides_extra is None:
41
+ hydra_overrides_extra = []
42
+ hydra_overrides = [
43
+ ## Extension: LLM prompt
44
+ "++model._target_=rynnec.model.predictor.SAM2VideoPredictor",
45
+ ]
46
+
47
+ if apply_postprocessing:
48
+ hydra_overrides_extra = hydra_overrides_extra.copy()
49
+ hydra_overrides_extra += [
50
+ # dynamically fall back to multi-mask if the single mask is not stable
51
+ "++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true",
52
+ "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05",
53
+ "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98",
54
+ # the sigmoid mask logits on interacted frames with clicks in the memory encoder so that the encoded masks are exactly as what users see from clicking
55
+ # "++model.binarize_mask_from_pts_for_mem_enc=true",
56
+ # fill small holes in the low-res masks up to `fill_hole_area` (before resizing them to the original video resolution)
57
+ # "++model.fill_hole_area=8",
58
+ ]
59
+ hydra_overrides.extend(hydra_overrides_extra)
60
+
61
+ # Read config and init model
62
+ cfg = compose(config_name=cfg_path, overrides=hydra_overrides)
63
+ OmegaConf.resolve(cfg)
64
+ sam2_model = instantiate(cfg.model, _recursive_=True)
65
+ state_dict = load_checkpoint_with_prefix(ckpt_path)
66
+ load_state_dict_to_model(sam2_model, state_dict)
67
+
68
+ self.sam2_model = sam2_model
69
+
70
+ self.hidden_dim = self.sam2_model.hidden_dim
71
+
72
+ self.img_mean = (0.485, 0.456, 0.406)
73
+ self.img_std = (0.229, 0.224, 0.225)
74
+
75
+ def inject_language_embd(self, inference_state, language_embd):
76
+ num_frame = len(language_embd)
77
+ num_obj = len(language_embd[0])
78
+ mask_out = []
79
+ for frame_idx in range(num_frame):
80
+ frame_mask_out = []
81
+ for obj_idx in range(num_obj):
82
+ _language_embd = language_embd[frame_idx][obj_idx][None][None]
83
+ _, _, out_mask_logits = self.sam2_model.add_language_embd(inference_state, frame_idx, obj_idx + 100, _language_embd)
84
+ frame_mask_out.append(out_mask_logits)
85
+ frame_mask_out = torch.cat(frame_mask_out, dim=1)
86
+ mask_out.append(frame_mask_out)
87
+ mask_out = torch.cat(mask_out, dim=0)
88
+ return mask_out
89
+
90
+
91
+ def language_embd_inference(self, inference_state, language_embd):
92
+ num_frame = len(language_embd)
93
+ num_obj = len(language_embd[0])
94
+ mask_out = []
95
+ with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
96
+ for frame_idx in range(num_frame):
97
+ frame_mask_out = []
98
+
99
+ for obj_idx in range(num_obj):
100
+ _language_embd = language_embd[frame_idx][obj_idx][None][None]
101
+ _, _, out_mask_logits = self.sam2_model.add_language_embd(
102
+ inference_state,
103
+ frame_idx,
104
+ obj_idx + 100,
105
+ _language_embd,
106
+ inference=True,
107
+ )
108
+ frame_mask_out.append(out_mask_logits)
109
+ frame_mask_out = torch.cat(frame_mask_out, dim=1)
110
+ mask_out.append(frame_mask_out)
111
+
112
+
113
+ mask_out = []
114
+ for out_frame_idx, out_obj_ids, out_mask_logits in self.sam2_model.propagate_in_video(inference_state):
115
+ mask_out.append(out_mask_logits)
116
+ mask_out = torch.cat(mask_out, dim=0)
117
+ return mask_out
118
+
119
+ def get_sam2_embeddings(self, images):
120
+ return self.sam2_model.init_state(images)
121
+
122
+ def forward(self, batch):
123
+ raise NotImplementedError
124
+
125
+ def preprocess_image(self, image: torch.Tensor, dtype=torch.float32) -> torch.Tensor:
126
+ image = image / 255.
127
+
128
+ img_mean = torch.tensor(self.img_mean, dtype=dtype, device=image.device)[:, None, None]
129
+ img_std = torch.tensor(self.img_std, dtype=dtype, device=image.device)[:, None, None]
130
+ image -= img_mean
131
+ image /= img_std
132
+
133
+ return image
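The preprocessing above is plain ImageNet-style normalization; a standalone equivalent for reference:

# Standalone equivalent of SAM2.preprocess_image (constants mirror __init__).
import torch

img_mean = torch.tensor([0.485, 0.456, 0.406])[:, None, None]
img_std = torch.tensor([0.229, 0.224, 0.225])[:, None, None]

frame = torch.randint(0, 256, (3, 1024, 1024)).float()  # CHW frame in [0, 255]
frame = frame / 255.0
frame = (frame - img_mean) / img_std                     # what the SAM2 backbone expects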
rynnec/model/sam2_train.py ADDED
@@ -0,0 +1,134 @@
1
+ # Adapted from https://github.com/magic-research/Sa2VA/blob/main/projects/llava_sam2/models/sam2_train.py.
2
+ # Below is the original copyright:
3
+ # coding=utf-8
4
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ import os.path
19
+
20
+ import torch
21
+ import torch.nn as nn
22
+
23
+ from hydra import compose
24
+ from hydra.utils import instantiate
25
+ from omegaconf import OmegaConf
26
+
27
+ from .utils import load_checkpoint_with_prefix, load_state_dict_to_model
28
+
29
+ BASE_DIR = 'pretrained/'
30
+
31
+
32
+ class SAM2TrainRunner(nn.Module):
33
+ def __init__(
34
+ self,
35
+ cfg_path: str = "sam2_hiera_l.yaml",
36
+ ckpt_path: str = "sam2_hiera_large.pt",
37
+ hydra_overrides_extra=None,
38
+ apply_postprocessing=True,
39
+ ):
40
+ super().__init__()
41
+
42
+ import third_parts.sam2 # noqa: F401
43
+
44
+ if hydra_overrides_extra is None:
45
+ hydra_overrides_extra = []
46
+ hydra_overrides = [
47
+ ## Extension: LLM prompt
48
+ "++model._target_=rynnec.model.extension.SAM2Base",
49
+ ]
50
+
51
+ if apply_postprocessing:
52
+ hydra_overrides_extra = hydra_overrides_extra.copy()
53
+
54
+ hydra_overrides.extend(hydra_overrides_extra)
55
+
56
+ # Read config and init model
57
+ cfg = compose(config_name=cfg_path, overrides=hydra_overrides)
58
+ OmegaConf.resolve(cfg)
59
+ sam2_model = instantiate(cfg.model, _recursive_=True)
60
+ state_dict = load_checkpoint_with_prefix(ckpt_path)
61
+ load_state_dict_to_model(sam2_model, state_dict)
62
+
63
+ self.sam2_model = sam2_model
64
+
65
+ self.hidden_dim = self.sam2_model.hidden_dim
66
+ self.img_mean = (0.485, 0.456, 0.406)
67
+ self.img_std = (0.229, 0.224, 0.225)
68
+
69
+ def preprocess_image(self, image: torch.Tensor) -> torch.Tensor:
70
+ image = image / 255.
71
+ img_mean = torch.tensor(self.img_mean, dtype=image.dtype, device=image.device)[:, None, None]
72
+ img_std = torch.tensor(self.img_std, dtype=image.dtype, device=image.device)[:, None, None]
73
+ image -= img_mean
74
+ image /= img_std
75
+ return image
76
+
77
+ def inject_language_embd(self, sam_states, language_embd, nf_nobj=None):
78
+ high_res_features = [
79
+ x.permute(1, 2, 0).view(x.size(1), x.size(2), *s)
80
+ for x, s in zip(sam_states['current_vision_feats'][:-1], sam_states['feat_sizes'][:-1])
81
+ ]
82
+
83
+ B = sam_states['current_vision_feats'][-1].size(1) # batch size on this frame
84
+ C = self.hidden_dim
85
+ H, W = sam_states['feat_sizes'][-1]
86
+
87
+ if self.sam2_model.directly_add_no_mem_embed:
88
+ # directly add no-mem embedding (instead of using the transformer encoder)
89
+ pix_feat_with_mem = sam_states['current_vision_feats'][-1] + self.sam2_model.no_mem_embed
90
+ pix_feat_with_mem = pix_feat_with_mem.permute(1, 2, 0).view(B, C, H, W)
91
+ else:
92
+ raise NotImplementedError("directly add no memory embedding is not implemented")
93
+ with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
94
+ _, _, _, low_res_masks, high_res_masks, obj_ptr, _, = self.sam2_model._forward_sam_heads(
95
+ backbone_features=pix_feat_with_mem,
96
+ point_inputs=None,
97
+ mask_inputs=None,
98
+ high_res_features=high_res_features,
99
+ multimask_output=self.sam2_model._use_multimask(is_init_cond_frame=True, point_inputs=None),
100
+ # Inject language Embed if possible
101
+ language_embd=language_embd,
102
+ )
103
+
104
+ if nf_nobj is not None:
105
+ pred_masks = low_res_masks.squeeze(1)
106
+ pred_masks = pred_masks.unflatten(0, nf_nobj)
107
+ else:
108
+ pred_masks = low_res_masks
109
+ return pred_masks
110
+
111
+ def get_sam2_embeddings(self, images, expand_size=1):
112
+ # Step 1: inference the backbone with the images
113
+ with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
114
+ feats = self.sam2_model.forward_image(images)
115
+
116
+ if expand_size > 1:
117
+ # feats['vision_features'] = feats['vision_features'][:, None].expand(-1, expand_size, -1, -1, -1).flatten(0, 1)
118
+ for i, feat in enumerate(feats["backbone_fpn"]):
119
+ feats["backbone_fpn"][i] = feat[:, None].expand(-1, expand_size, -1, -1, -1).flatten(0, 1)
120
+ for i, pos in enumerate(feats["vision_pos_enc"]):
121
+ pos = pos[:, None].expand(-1, expand_size, -1, -1, -1).flatten(0, 1)
122
+ feats["vision_pos_enc"][i] = pos
123
+
124
+ # Step 2: Process the features to output
125
+ _, current_vision_feats, current_vision_pos_embeds, feat_sizes = self.sam2_model._prepare_backbone_features(feats)
126
+
127
+ return {
128
+ "current_vision_feats": current_vision_feats,
129
+ "current_vision_pos_embeds": current_vision_pos_embeds,
130
+ "feat_sizes": feat_sizes,
131
+ }
132
+
133
+ def forward(self, batch):
134
+ raise NotImplementedError
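The `expand_size` handling in `get_sam2_embeddings` repeats each frame's backbone features once per object before flattening, so the SAM2 heads see one feature copy per (frame, object) pair. A standalone sketch with made-up sizes:

# Standalone sketch of the expand_size logic above (sizes are made up).
import torch

feats = torch.randn(4, 256, 64, 64)   # (num_frames, C, H, W) backbone features
expand_size = 3                       # three objects share the same frames

feats = feats[:, None].expand(-1, expand_size, -1, -1, -1).flatten(0, 1)
print(feats.shape)                    # torch.Size([12, 256, 64, 64]): frame-major, object-minor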
rynnec/model/utils.py ADDED
@@ -0,0 +1,61 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+ import torch.nn as nn
4
+ from torch import Tensor
5
+ import logging
6
+ from huggingface_hub import hf_hub_download
7
+ import functools
8
+ from typing import Callable, Optional
9
+
10
+ def process_video_gt_masks(gt_masks, num_frames, num_objs):
11
+ gt_masks_processed = []
12
+ for i in range(num_frames):
13
+ for j in range(num_objs):
14
+ gt_masks_processed.append(gt_masks[j*num_frames+i])
15
+ return gt_masks_processed
16
+
17
+ def load_checkpoint_with_prefix(filename, prefix=None, map_location='cpu', logger='current'):
18
+ HF_HUB_PREFIX = 'hf-hub:'
19
+ if filename.startswith(HF_HUB_PREFIX):
20
+ model_id = filename[len(HF_HUB_PREFIX):]
21
+ filename = hf_hub_download(model_id, 'pytorch_model.bin')
22
+
23
+ checkpoint = torch.load(filename, map_location=map_location)
24
+
25
+ if 'state_dict' in checkpoint:
26
+ state_dict = checkpoint['state_dict']
27
+ elif 'model' in checkpoint:
28
+ state_dict = checkpoint['model']
29
+ else:
30
+ state_dict = checkpoint
31
+ if not prefix:
32
+ return state_dict
33
+ if not prefix.endswith('.'):
34
+ prefix += '.'
35
+ prefix_len = len(prefix)
36
+
37
+ state_dict = {
38
+ k[prefix_len:]: v
39
+ for k, v in state_dict.items() if k.startswith(prefix)
40
+ }
41
+
42
+ assert state_dict, f'{prefix} is not in the pretrained model'
43
+ return state_dict
44
+
45
+
46
+ def load_state_dict_to_model(model, state_dict, logger='current'):
47
+ missing_keys, unexpected_keys = model.load_state_dict(state_dict)
48
+ if missing_keys:
49
+ raise RuntimeError(f'Missing keys when loading state dict: {missing_keys}')
50
+ if unexpected_keys:
51
+ raise RuntimeError(f'Unexpected keys when loading state dict: {unexpected_keys}')
52
+
53
+ def genetate_video_pred_embeddings(pred_embeddings_list, frames_per_batch):
54
+ assert len(pred_embeddings_list) == len(frames_per_batch), \
55
+ f"Lengths do not match: len(pred_embeddings_list)={len(pred_embeddings_list)}, len(frames_per_batch)={len(frames_per_batch)}"
56
+
57
+ pred_embeddings_list_video = []
58
+ for pred_embedding_batch, frame_nums in zip(pred_embeddings_list, frames_per_batch):
59
+ pred_embeddings_list_video += [pred_embedding_batch] * frame_nums
60
+ return pred_embeddings_list_video
61
+
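Putting the two checkpoint helpers together, a hedged usage sketch; the checkpoint id, prefix, and `sam2_module` are illustrative placeholders, not values taken from this repository:

# Load only the weights under a given prefix from a combined checkpoint, then apply them strictly.
state_dict = load_checkpoint_with_prefix(
    "hf-hub:some-org/some-model",     # or a local .pth / .bin path (placeholder id)
    prefix="grounding_encoder",       # assumed prefix; only keys starting with it are kept
)
load_state_dict_to_model(sam2_module, state_dict)   # raises if any key is missing or unexpected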
rynnec/model/videollama3_encoder/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ from .configuration_videollama3_encoder import Videollama3VisionEncoderConfig
2
+ from .image_processing_videollama3 import Videollama3ImageProcessor
3
+ from .modeling_videollama3_encoder import Videollama3VisionEncoderModel
rynnec/model/videollama3_encoder/configuration_videollama3_encoder.py ADDED
@@ -0,0 +1,71 @@
1
+ # Adopted from https://github.com/huggingface/transformers/blob/main/src/transformers/models/siglip/configuration_siglip.py.
2
+ # Below is the original copyright:
3
+ # coding=utf-8
4
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+ """VideoLLaMA3 vision encoder model configuration."""
18
+ import os
19
+ from typing import Union
20
+
21
+ from transformers import PretrainedConfig
22
+ from transformers.utils import logging
23
+
24
+
25
+ logger = logging.get_logger(__name__)
26
+
27
+ class Videollama3VisionEncoderConfig(PretrainedConfig):
28
+
29
+ model_type = "videollama3_vision_encoder"
30
+
31
+ def __init__(
32
+ self,
33
+ hidden_size=768,
34
+ intermediate_size=3072,
35
+ num_hidden_layers=12,
36
+ num_attention_heads=12,
37
+ num_channels=3,
38
+ patch_size=16,
39
+ hidden_act="gelu_pytorch_tanh",
40
+ layer_norm_eps=1e-6,
41
+ attention_dropout=0.0,
42
+ **kwargs,
43
+ ):
44
+ super().__init__(**kwargs)
45
+
46
+ self.hidden_size = hidden_size
47
+ self.intermediate_size = intermediate_size
48
+ self.num_hidden_layers = num_hidden_layers
49
+ self.num_attention_heads = num_attention_heads
50
+ self.num_channels = num_channels
51
+ self.patch_size = patch_size
52
+ self.attention_dropout = attention_dropout
53
+ self.layer_norm_eps = layer_norm_eps
54
+ self.hidden_act = hidden_act
55
+
56
+ # @classmethod
57
+ # def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
58
+ # cls._set_token_in_kwargs(kwargs)
59
+
60
+ # config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
61
+
62
+ # p
63
+ # config_dict = config_dict["vision_config"]
64
+
65
+ # if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
66
+ # logger.warning(
67
+ # f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
68
+ # f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
69
+ # )
70
+
71
+ # return cls.from_dict(config_dict, **kwargs)
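Because this is a plain `PretrainedConfig` subclass, it can be constructed directly with keyword overrides; a minimal sketch (all values except the layer count are the defaults above, and the override is purely illustrative):

config = Videollama3VisionEncoderConfig(
    hidden_size=768,
    num_hidden_layers=24,   # illustrative override of the default 12
    patch_size=16,
)
print(config.model_type)    # "videollama3_vision_encoder"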
rynnec/model/videollama3_encoder/image_processing_videollama3.py ADDED
@@ -0,0 +1,473 @@
1
+ # Adopted from https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_vl/image_processing_qwen2_vl.py.
2
+ # Below is the original copyright:
3
+ # Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
4
+ #
5
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
6
+ # and OPT implementations in this library. It has been modified from its
7
+ # original forms to accommodate minor architectural differences compared
8
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
9
+ #
10
+ # Licensed under the Apache License, Version 2.0 (the "License");
11
+ # you may not use this file except in compliance with the License.
12
+ # You may obtain a copy of the License at
13
+ #
14
+ # http://www.apache.org/licenses/LICENSE-2.0
15
+ #
16
+ # Unless required by applicable law or agreed to in writing, software
17
+ # distributed under the License is distributed on an "AS IS" BASIS,
18
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
19
+ # See the License for the specific language governing permissions and
20
+ # limitations under the License.
21
+ """Image processor class for VideoLLaMA3."""
22
+
23
+ import math
24
+ from typing import Dict, List, Optional, Union
25
+
26
+ import numpy as np
27
+
28
+ import torch
29
+ from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
30
+ from transformers.image_utils import ImageInput
31
+ from transformers.image_transforms import (
32
+ convert_to_rgb,
33
+ resize,
34
+ to_channel_dimension_format,
35
+ )
36
+ from transformers.image_utils import (
37
+ OPENAI_CLIP_MEAN,
38
+ OPENAI_CLIP_STD,
39
+ ChannelDimension,
40
+ ImageInput,
41
+ PILImageResampling,
42
+ VideoInput,
43
+ get_image_size,
44
+ infer_channel_dimension_format,
45
+ is_scaled_image,
46
+ is_valid_image,
47
+ make_list_of_images,
48
+ to_numpy_array,
49
+ )
50
+ from transformers.utils import TensorType, is_vision_available, logging
51
+
52
+
53
+ logger = logging.get_logger(__name__)
54
+
55
+
56
+ if is_vision_available():
57
+ from PIL import Image
58
+
59
+
60
+ def is_valid_video(video) -> bool:
61
+ if isinstance(video, (list, tuple)):
62
+ return all(is_valid_image(frame) for frame in video)
63
+ elif isinstance(video, np.ndarray):
64
+ return video.ndim == 4
65
+ elif isinstance(video, torch.Tensor):
66
+ return video.ndim == 4
67
+ return False
68
+
69
+
70
+ def make_batched_images(images) -> List[List[ImageInput]]:
71
+ """
72
+ Accepts images in list or nested list format, and makes a list of images for preprocessing.
73
+
74
+ Args:
75
+ images (`Union[List[List[ImageInput]], List[ImageInput], ImageInput]`):
76
+ The input image.
77
+
78
+ Returns:
79
+ list: A list of images.
80
+ """
81
+ if isinstance(images, (list, tuple)):
82
+ # list of images/videos
83
+ if not all(is_valid_video(image) or is_valid_image(image) for image in images):
84
+ raise ValueError(f"Could not make batched images from {images}")
85
+ return images
86
+ elif is_valid_video(images) or is_valid_image(images):
87
+ # single image/video
88
+ return [images]
89
+
90
+ raise ValueError(f"Could not make batched images from {images}")
91
+
92
+
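In short, `make_batched_images` accepts a single image, a single video (a list/array/tensor of frames), or a mixed list of both, and always returns a list; a short sketch using PIL (available behind the `is_vision_available()` guard above):

from PIL import Image

img = Image.new("RGB", (640, 480))
video = [Image.new("RGB", (640, 480)) for _ in range(8)]   # a video as a list of frames

make_batched_images(img)            # -> [img]
make_batched_images([img, video])   # -> [img, video]  (already a valid batch, returned as-is)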
93
+ def simple_batched_resize(
94
+ images, factor: int = 28, min_tokens: int = 4 * 4, max_tokens: int = 16384, input_data_format: str = None
95
+ ):
96
+ min_pixels = min_tokens * factor * factor
97
+ max_pixels = max_tokens * factor * factor
98
+
99
+ num_images = 0
100
+ for image in images:
101
+ if is_valid_video(image):
102
+ num_images += len(image)
103
+ else:
104
+ num_images += 1
105
+
106
+ image_sizes = []
107
+ for image in images:
108
+ if is_valid_video(image):
109
+ image = image[0]
110
+ if isinstance(image, Image.Image):
111
+ width, height = image.size
112
+ else:
113
+ height, width = get_image_size(image, channel_dim=input_data_format)
114
+ image_sizes.append([height, width])
115
+
116
+ tmp_image_sizes = []
117
+ for height, width in image_sizes:
118
+ h_bar = round(height / factor) * factor
119
+ w_bar = round(width / factor) * factor
120
+ if h_bar * w_bar > (max_pixels // num_images):
121
+ beta = math.sqrt((height * width) / (max_pixels // num_images))
122
+ h_bar = math.floor(height / beta / factor) * factor
123
+ w_bar = math.floor(width / beta / factor) * factor
124
+ # per image min_pixels
125
+ if h_bar * w_bar < min_pixels:
126
+ beta = math.sqrt(min_pixels / (height * width))
127
+ h_bar = math.ceil(height * beta / factor) * factor
128
+ w_bar = math.ceil(width * beta / factor) * factor
129
+ tmp_image_sizes.append((h_bar, w_bar))
130
+ image_sizes = tmp_image_sizes
131
+ return image_sizes
132
+
133
+
134
+ def batched_resize(
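To make the resizing rule concrete: each image is first snapped to a multiple of `factor` (patch size times merge size) and only scaled down if the per-image budget `max_pixels // num_images` is exceeded. A worked example for a single 1080x1920 frame with the defaults (factor=28, max_tokens=16384):

import numpy as np

frame = np.zeros((1080, 1920, 3), dtype=np.uint8)
simple_batched_resize([frame], factor=28)
# -> [(1092, 1932)]: round(1080 / 28) * 28 = 1092 and round(1920 / 28) * 28 = 1932,
#    and 1092 * 1932 pixels is well under the budget of 16384 * 28 * 28.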
135
+ images, factors: List[int], min_tokens: int = 4 * 4, max_tokens: int = 16384, input_data_format: str = None
136
+ ):
137
+ image_sizes = []
138
+ for image in images:
139
+ if is_valid_video(image):
140
+ num_frame = len(image)
141
+ image = image[0]
142
+ else:
143
+ num_frame = 1
144
+ if isinstance(image, Image.Image):
145
+ width, height = image.size
146
+ else:
147
+ height, width = get_image_size(image, channel_dim=input_data_format)
148
+ image_sizes.append([num_frame, height, width])
149
+
150
+ # global max_pixels
151
+ smart_scale_factors = 1.0
152
+ total_tokens = 0
153
+ for (num_frame, height, width), factor in zip(image_sizes, factors):
154
+ total_tokens += num_frame * math.ceil(height / factor) * math.ceil(width / factor)
155
+
156
+ # TODO: add min_pixels
157
+ if total_tokens > max_tokens:
158
+ beta = math.sqrt(total_tokens / max_tokens)
159
+ tmp_image_sizes = []
160
+ for (_, height, width), factor in zip(image_sizes, factors):
161
+ h_bar = math.floor(height / beta / factor) * factor
162
+ w_bar = math.floor(width / beta / factor) * factor
163
+ tmp_image_sizes.append((h_bar, w_bar))
164
+ image_sizes = tmp_image_sizes
165
+ else:
166
+ tmp_image_sizes = []
167
+ for (_, height, width), factor in zip(image_sizes, factors):
168
+ height = round(height / factor) * factor
169
+ width = round(width / factor) * factor
170
+ tmp_image_sizes.append((height, width))
171
+ image_sizes = tmp_image_sizes
172
+
173
+ return image_sizes
174
+
175
+
176
+ class Videollama3ImageProcessor(BaseImageProcessor):
177
+ r"""
178
+ Constructs a DAMOVL image processor that dynamically resizes images based on the original images.
179
+
180
+ Args:
181
+ do_resize (`bool`, *optional*, defaults to `True`):
182
+ Whether to resize the image's (height, width) dimensions.
183
+ resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`):
184
+ Resampling filter to use when resizing the image.
185
+ do_rescale (`bool`, *optional*, defaults to `True`):
186
+ Whether to rescale the image by the specified scale `rescale_factor`.
187
+ rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
188
+ Scale factor to use if rescaling the image.
189
+ do_normalize (`bool`, *optional*, defaults to `True`):
190
+ Whether to normalize the image.
191
+ image_mean (`float` or `List[float]`, *optional*, defaults to `[0.48145466, 0.4578275, 0.40821073]`):
192
+ Mean to use if normalizing the image. This is a float or list of floats for each channel in the image.
193
+ image_std (`float` or `List[float]`, *optional*, defaults to `[0.26862954, 0.26130258, 0.27577711]`):
194
+ Standard deviation to use if normalizing the image. This is a float or list of floats for each channel in the image.
195
+ do_convert_rgb (`bool`, *optional*, defaults to `True`):
196
+ Whether to convert the image to RGB.
197
+ min_pixels (`int`, *optional*, defaults to `56 * 56`):
198
+ The min pixels of the image to resize the image.
199
+ max_pixels (`int`, *optional*, defaults to `28 * 28 * 1280`):
200
+ The max pixels of the image to resize the image.
201
+ patch_size (`int`, *optional*, defaults to 14):
202
+ The spatial patch size of the vision encoder.
203
+ """
204
+
205
+ model_input_names = ["pixel_values", "grid_sizes", "merge_sizes"]
206
+
207
+ def __init__(
208
+ self,
209
+ do_resize: bool = True,
210
+ resample: PILImageResampling = PILImageResampling.BICUBIC,
211
+ do_rescale: bool = True,
212
+ rescale_factor: Union[int, float] = 1 / 255,
213
+ do_normalize: bool = True,
214
+ image_mean: Optional[Union[float, List[float]]] = None,
215
+ image_std: Optional[Union[float, List[float]]] = None,
216
+ do_convert_rgb: bool = True,
217
+ min_tokens: int = 4 * 4,
218
+ max_tokens: int = 16384,
219
+ patch_size: int = 14,
220
+ **kwargs,
221
+ ) -> None:
222
+ super().__init__(**kwargs)
223
+ self.do_resize = do_resize
224
+ self.resample = resample
225
+ self.do_rescale = do_rescale
226
+ self.rescale_factor = rescale_factor
227
+ self.do_normalize = do_normalize
228
+ self.image_mean = image_mean if image_mean is not None else OPENAI_CLIP_MEAN
229
+ self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD
230
+ self.min_tokens = min_tokens
231
+ self.max_tokens = max_tokens
232
+ self.patch_size = patch_size
233
+ self.do_convert_rgb = do_convert_rgb
234
+
235
+ def _preprocess(
236
+ self,
237
+ images: Union[ImageInput, VideoInput],
238
+ target_size: List[int],
239
+ merge_size: int = 1,
240
+ do_resize: bool = None,
241
+ resample: PILImageResampling = None,
242
+ do_rescale: bool = None,
243
+ rescale_factor: float = None,
244
+ do_normalize: bool = None,
245
+ image_mean: Optional[Union[float, List[float]]] = None,
246
+ image_std: Optional[Union[float, List[float]]] = None,
247
+ do_convert_rgb: bool = None,
248
+ data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
249
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
250
+ ):
251
+ """
252
+ Preprocess an image or batch of images. Copy of the `preprocess` method from `CLIPImageProcessor`.
253
+
254
+ Args:
255
+ images (`ImageInput`):
256
+ Image or batch of images to preprocess. Expects pixel values ranging from 0 to 255. If pixel values range from 0 to 1, set `do_rescale=False`.
257
+ target_size (`List[int]`):
258
+ The target size to resize the image to. Should be a list of two integers: [target_height, target_width].
259
+ merge_size (`int`, *optional*, defaults to `1`):
260
+ The merge size after the vision encoder.
261
+ do_resize (`bool`, *optional*, defaults to `self.do_resize`):
262
+ Whether to resize the image.
263
+ resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
264
+ Resampling filter to use if resizing the image. This can be one of the `PILImageResampling` enums.
265
+ do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
266
+ Whether to rescale the image.
267
+ rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
268
+ Scale factor to use if rescaling the image.
269
+ do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
270
+ Whether to normalize the image.
271
+ image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
272
+ Mean to use if normalizing the image. Can be a float or a list of floats corresponding to the number of channels in the image.
273
+ image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
274
+ Standard deviation to use if normalizing the image. Can be a float or a list of floats corresponding to the number of channels in the image.
275
+ do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
276
+ Whether to convert the image to RGB.
277
+ data_format (`ChannelDimension`, *optional*, defaults to `ChannelDimension.FIRST`):
278
+ The channel dimension format for the output image. Can be one of:
279
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
280
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
281
+ - Unset: Use the channel dimension format of the input image.
282
+ input_data_format (`ChannelDimension` or `str`, *optional*):
283
+ The channel dimension format for the input image. Can be one of:
284
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
285
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
286
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
287
+ """
288
+ images = make_list_of_images(images)
289
+
290
+ if do_convert_rgb:
291
+ images = [convert_to_rgb(image) for image in images]
292
+
293
+ # All transformations expect numpy arrays.
294
+ images = [to_numpy_array(image) for image in images]
295
+
296
+ if is_scaled_image(images[0]) and do_rescale:
297
+ logger.warning_once(
298
+ "It looks like you are trying to rescale already rescaled images. If the input"
299
+ " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
300
+ )
301
+ if input_data_format is None:
302
+ # We assume that all images have the same channel dimension format.
303
+ input_data_format = infer_channel_dimension_format(images[0])
304
+
305
+ height, width = get_image_size(images[0], channel_dim=input_data_format)
306
+ resized_height, resized_width = height, width
307
+ processed_images = []
308
+ for image in images:
309
+ if do_resize:
310
+ resized_height, resized_width = target_size
311
+ image = resize(
312
+ image, size=(resized_height, resized_width), resample=resample, input_data_format=input_data_format
313
+ )
314
+
315
+ if do_rescale:
316
+ image = self.rescale(image, scale=rescale_factor, input_data_format=input_data_format)
317
+
318
+ if do_normalize:
319
+ image = self.normalize(
320
+ image=image, mean=image_mean, std=image_std, input_data_format=input_data_format
321
+ )
322
+
323
+ image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
324
+ processed_images.append(image)
325
+
326
+ patches = np.array(processed_images)
327
+ if data_format == ChannelDimension.LAST:
328
+ patches = patches.transpose(0, 3, 1, 2)
329
+ t = patches.shape[0]
330
+ channel = patches.shape[1]
331
+ grid_h, grid_w = resized_height // self.patch_size, resized_width // self.patch_size
332
+ patches = patches.reshape(
333
+ t,
334
+ channel,
335
+ grid_h // merge_size,
336
+ merge_size,
337
+ self.patch_size,
338
+ grid_w // merge_size,
339
+ merge_size,
340
+ self.patch_size,
341
+ )
342
+ patches = patches.transpose(0, 2, 5, 3, 6, 1, 4, 7)
343
+ flatten_patches = patches.reshape(
344
+ t * grid_h * grid_w, channel * self.patch_size * self.patch_size
345
+ )
346
+
347
+ return flatten_patches, (t, grid_h, grid_w)
348
+
349
+ def preprocess(
350
+ self,
351
+ images: ImageInput,
352
+ do_resize: bool = None,
353
+ resample: PILImageResampling = None,
354
+ do_rescale: bool = None,
355
+ rescale_factor: float = None,
356
+ do_normalize: bool = None,
357
+ image_mean: Optional[Union[float, List[float]]] = None,
358
+ image_std: Optional[Union[float, List[float]]] = None,
359
+ do_convert_rgb: bool = None,
360
+ merge_size: Optional[Union[int, List[int]]] = None,
361
+ return_tensors: Optional[Union[str, TensorType]] = None,
362
+ data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
363
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
364
+ ):
365
+ """
366
+ Args:
367
+ images (`ImageInput`):
368
+ Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
369
+ passing in images with pixel values between 0 and 1, set `do_rescale=False`.
370
+ do_resize (`bool`, *optional*, defaults to `self.do_resize`):
371
+ Whether to resize the image.
372
+ resample (`int`, *optional*, defaults to `self.resample`):
373
+ Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
374
+ has an effect if `do_resize` is set to `True`.
375
+ do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
376
+ Whether to rescale the image.
377
+ rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
378
+ Rescale factor to rescale the image by if `do_rescale` is set to `True`.
379
+ do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
380
+ Whether to normalize the image.
381
+ image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
382
+ Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
383
+ image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
384
+ Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to
385
+ `True`.
386
+ do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
387
+ Whether to convert the image to RGB.
388
+ return_tensors (`str` or `TensorType`, *optional*):
389
+ The type of tensors to return. Can be one of:
390
+ - Unset: Return a list of `np.ndarray`.
391
+ - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
392
+ - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
393
+ - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
394
+ - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
395
+ data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
396
+ The channel dimension format for the output image. Can be one of:
397
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
398
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
399
+ - Unset: Use the channel dimension format of the input image.
400
+ input_data_format (`ChannelDimension` or `str`, *optional*):
401
+ The channel dimension format for the input image. If unset, the channel dimension format is inferred
402
+ from the input image. Can be one of:
403
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
404
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
405
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
406
+
407
+ """
408
+ do_resize = do_resize if do_resize is not None else self.do_resize
409
+ resample = resample if resample is not None else self.resample
410
+ do_rescale = do_rescale if do_rescale is not None else self.do_rescale
411
+ rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
412
+ do_normalize = do_normalize if do_normalize is not None else self.do_normalize
413
+ image_mean = image_mean if image_mean is not None else self.image_mean
414
+ image_std = image_std if image_std is not None else self.image_std
415
+ merge_size = merge_size if merge_size is not None else self.merge_size
416
+ do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
417
+
418
+ images = make_batched_images(images)
419
+
420
+ if isinstance(merge_size, (list, tuple)):
421
+ assert len(merge_size) == len(images), "Merge size must be the same length as images."
422
+ merge_sizes = merge_size
423
+ else:
424
+ merge_sizes = [merge_size for _ in images]
425
+
426
+ if all(merge_size == merge_sizes[0] for merge_size in merge_sizes):
427
+ target_sizes = simple_batched_resize(
428
+ images,
429
+ factor=self.patch_size * merge_sizes[0],
430
+ min_tokens=self.min_tokens,
431
+ max_tokens=self.max_tokens,
432
+ input_data_format=input_data_format,
433
+ )
434
+ else:
435
+ target_sizes = batched_resize(
436
+ images,
437
+ factors=[self.patch_size * merge_size for merge_size in merge_sizes],
438
+ min_tokens=self.min_tokens,
439
+ max_tokens=self.max_tokens,
440
+ input_data_format=input_data_format,
441
+ )
442
+
443
+ pixel_values, grid_sizes = [], []
444
+ for image, merge_size, target_size in zip(images, merge_sizes, target_sizes):
445
+ patches, grid_size = self._preprocess(
446
+ image,
447
+ target_size=target_size,
448
+ merge_size=merge_size,
449
+ do_resize=do_resize,
450
+ resample=resample,
451
+ do_rescale=do_rescale,
452
+ rescale_factor=rescale_factor,
453
+ do_normalize=do_normalize,
454
+ image_mean=image_mean,
455
+ image_std=image_std,
456
+ data_format=data_format,
457
+ do_convert_rgb=do_convert_rgb,
458
+ input_data_format=input_data_format,
459
+ )
460
+ pixel_values.append(patches)
461
+ grid_sizes.append(grid_size)
462
+
463
+ pixel_values = np.concatenate(pixel_values, axis=0)
464
+ grid_sizes = np.array(grid_sizes)
465
+ merge_sizes = np.array(merge_sizes)
466
+
467
+ data = {
468
+ "pixel_values": pixel_values,
469
+ "grid_sizes": grid_sizes,
470
+ "merge_sizes": merge_sizes,
471
+ }
472
+
473
+ return BatchFeature(data=data, tensor_type=return_tensors)
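End to end, the processor turns images or videos into flattened patches plus their grid and merge sizes. A usage sketch; `merge_size` is passed explicitly here because, as written above, `__init__` does not store a default `merge_size` attribute for `preprocess` to fall back on:

from PIL import Image

processor = Videollama3ImageProcessor(patch_size=14, max_tokens=16384)
out = processor.preprocess(Image.new("RGB", (448, 448)), merge_size=1, return_tensors="pt")
print(out["pixel_values"].shape)               # (1024, 588) = (num_patches, 3 * 14 * 14)
print(out["grid_sizes"], out["merge_sizes"])   # tensor([[ 1, 32, 32]]) tensor([1])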
rynnec/model/videollama3_encoder/modeling_videollama3_encoder.py ADDED
@@ -0,0 +1,555 @@
1
+ # Adopted from https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py.
2
+ # Below is the original copyright:
3
+ # Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
4
+ #
5
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
6
+ # and OPT implementations in this library. It has been modified from its
7
+ # original forms to accommodate minor architectural differences compared
8
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
9
+ #
10
+ # Licensed under the Apache License, Version 2.0 (the "License");
11
+ # you may not use this file except in compliance with the License.
12
+ # You may obtain a copy of the License at
13
+ #
14
+ # http://www.apache.org/licenses/LICENSE-2.0
15
+ #
16
+ # Unless required by applicable law or agreed to in writing, software
17
+ # distributed under the License is distributed on an "AS IS" BASIS,
18
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
19
+ # See the License for the specific language governing permissions and
20
+ # limitations under the License.
21
+ """PyTorch VideoLLaMA3 vision encoder model."""
22
+
23
+ import importlib.util
24
+ import os.path as osp
25
+ import math
26
+ import warnings
27
+
28
+ import torch
29
+ import torch.nn as nn
30
+ import torch.nn.functional as F
31
+ import torch.utils.checkpoint
32
+ from torch.nn.init import _calculate_fan_in_and_fan_out
33
+
34
+ from transformers.activations import ACT2FN
35
+ from transformers.modeling_utils import PreTrainedModel
36
+ from transformers.utils import is_flash_attn_2_available
37
+
38
+ if is_flash_attn_2_available():
39
+ from flash_attn import flash_attn_varlen_func
40
+ else:
41
+ flash_attn_varlen_func = None
42
+
43
+ try:
44
+ from .configuration_videollama3_encoder import Videollama3VisionEncoderConfig
45
+ except ImportError:
46
+ spec = importlib.util.spec_from_file_location(
47
+ "configuration_videollama3_encoder",
48
+ osp.join(osp.dirname(__file__), "configuration_videollama3_encoder.py"),
49
+ )
50
+ configuration_videollama3_encoder = importlib.util.module_from_spec(spec)
51
+ spec.loader.exec_module(configuration_videollama3_encoder)
52
+ Videollama3VisionEncoderConfig = getattr(
53
+ configuration_videollama3_encoder,
54
+ "Videollama3VisionEncoderConfig",
55
+ )
56
+
57
+ LayerNorm = nn.LayerNorm
58
+
59
+
60
+ def _trunc_normal_(tensor, mean, std, a, b):
61
+ # Cut & paste from PyTorch official master until it's in a few official releases - RW
62
+ # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
63
+ def norm_cdf(x):
64
+ # Computes standard normal cumulative distribution function
65
+ return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
66
+
67
+ if (mean < a - 2 * std) or (mean > b + 2 * std):
68
+ warnings.warn(
69
+ "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
70
+ "The distribution of values may be incorrect.",
71
+ stacklevel=2,
72
+ )
73
+
74
+ # Values are generated by using a truncated uniform distribution and
75
+ # then using the inverse CDF for the normal distribution.
76
+ # Get upper and lower cdf values
77
+ l = norm_cdf((a - mean) / std)
78
+ u = norm_cdf((b - mean) / std)
79
+
80
+ # Uniformly fill tensor with values from [l, u], then translate to
81
+ # [2l-1, 2u-1].
82
+ tensor.uniform_(2 * l - 1, 2 * u - 1)
83
+
84
+ # Use inverse cdf transform for normal distribution to get truncated
85
+ # standard normal
86
+ tensor.erfinv_()
87
+
88
+ # Transform to proper mean, std
89
+ tensor.mul_(std * math.sqrt(2.0))
90
+ tensor.add_(mean)
91
+
92
+ # Clamp to ensure it's in the proper range
93
+ tensor.clamp_(min=a, max=b)
94
+
95
+
96
+ def trunc_normal_tf_(
97
+ tensor: torch.Tensor, mean: float = 0.0, std: float = 1.0, a: float = -2.0, b: float = 2.0
98
+ ) -> torch.Tensor:
99
+ """Fills the input Tensor with values drawn from a truncated
100
+ normal distribution. The values are effectively drawn from the
101
+ normal distribution :math:`\\mathcal{N}(\text{mean}, \text{std}^2)`
102
+ with values outside :math:`[a, b]` redrawn until they are within
103
+ the bounds. The method used for generating the random values works
104
+ best when :math:`a \\leq \text{mean} \\leq b`.
105
+
106
+ NOTE: this 'tf' variant behaves closer to Tensorflow / JAX impl where the
107
+ bounds [a, b] are applied when sampling the normal distribution with mean=0, std=1.0
108
+ and the result is subsequently scaled and shifted by the mean and std args.
109
+
110
+ Args:
111
+ tensor: an n-dimensional `torch.Tensor`
112
+ mean: the mean of the normal distribution
113
+ std: the standard deviation of the normal distribution
114
+ a: the minimum cutoff value
115
+ b: the maximum cutoff value
116
+ """
117
+ with torch.no_grad():
118
+ _trunc_normal_(tensor, 0, 1.0, a, b)
119
+ tensor.mul_(std).add_(mean)
120
+
121
+
122
+ def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="normal"):
123
+ fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
124
+ if mode == "fan_in":
125
+ denom = fan_in
126
+ elif mode == "fan_out":
127
+ denom = fan_out
128
+ elif mode == "fan_avg":
129
+ denom = (fan_in + fan_out) / 2
130
+
131
+ variance = scale / denom
132
+
133
+ if distribution == "truncated_normal":
134
+ # constant is stddev of standard normal truncated to (-2, 2)
135
+ trunc_normal_tf_(tensor, std=math.sqrt(variance) / 0.87962566103423978)
136
+ elif distribution == "normal":
137
+ with torch.no_grad():
138
+ tensor.normal_(std=math.sqrt(variance))
139
+ elif distribution == "uniform":
140
+ bound = math.sqrt(3 * variance)
141
+ with torch.no_grad():
142
+ tensor.uniform_(-bound, bound)
143
+ else:
144
+ raise ValueError(f"invalid distribution {distribution}")
145
+
146
+
147
+ def lecun_normal_(tensor):
148
+ variance_scaling_(tensor, mode="fan_in", distribution="truncated_normal")
149
+
150
+
151
+ def default_flax_embed_init(tensor):
152
+ variance_scaling_(tensor, mode="fan_in", distribution="normal")
153
+
154
+
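These initializers mirror the JAX/Flax conventions (variance scaled by fan-in, with a correction factor for the truncated normal); a small sketch of how they would typically be applied to fresh modules:

import torch.nn as nn

linear = nn.Linear(768, 3072)
lecun_normal_(linear.weight)           # fan-in scaled truncated normal, std roughly 1 / sqrt(768)
nn.init.zeros_(linear.bias)

embed = nn.Embedding(1000, 768)
default_flax_embed_init(embed.weight)  # fan-in scaled plain normal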
155
+ # Copied from transformers.models.llama.modeling_llama.rotate_half
156
+ def rotate_half(x):
157
+ """Rotates half the hidden dims of the input."""
158
+ x1 = x[..., : x.shape[-1] // 2]
159
+ x2 = x[..., x.shape[-1] // 2 :]
160
+ return torch.cat((-x2, x1), dim=-1)
161
+
162
+
163
+ # def apply_rotary_pos_emb_vision(tensor: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
164
+ # orig_dtype = tensor.dtype
165
+ # tensor = tensor.float()
166
+ # cos = freqs.cos()
167
+ # sin = freqs.sin()
168
+ # cos = cos.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
169
+ # sin = sin.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
170
+ # output = (tensor * cos) + (rotate_half(tensor) * sin)
171
+ # output = output.to(orig_dtype)
172
+ # return output
173
+
174
+
175
+ def apply_rotary_pos_emb_vision(q, k, cos, sin) -> torch.Tensor:
176
+ orig_dtype = q.dtype
177
+ q, k = q.float(), k.float()
178
+ cos = cos.unsqueeze(1).float()
179
+ sin = sin.unsqueeze(1).float()
180
+ q = (q * cos) + (rotate_half(q) * sin)
181
+ k = (k * cos) + (rotate_half(k) * sin)
182
+ return q.to(orig_dtype), k.to(orig_dtype)
183
+
184
+
185
+ class VisionRotaryEmbedding(nn.Module):
186
+
187
+ def __init__(self, dim: int, theta: float = 10000.0) -> None:
188
+ super().__init__()
189
+ inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
190
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
191
+
192
+ def forward(self, seqlen: int) -> torch.Tensor:
193
+ seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
194
+ freqs = torch.outer(seq, self.inv_freq)
195
+ return freqs
196
+
197
+
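A shape sketch for the 2D rotary scheme (values are illustrative): the table stores `head_dim // 4` frequencies per axis, `rot_pos_emb` further below concatenates the height and width components into `head_dim // 2` per token, and the attention layers repeat cos/sin once more to the full `head_dim` before calling `apply_rotary_pos_emb_vision`.

import torch

head_dim = 64                       # hidden_size // num_attention_heads for the default config
rope = VisionRotaryEmbedding(head_dim // 2)
freqs = rope(seqlen=32)             # (32, head_dim // 4): one row per grid coordinate
cos = freqs.cos().repeat(1, 2)      # doubling mimics the height+width concatenation -> head_dim // 2
print(freqs.shape, cos.shape)       # torch.Size([32, 16]) torch.Size([32, 32])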
198
+ class Videollama3VisionEmbeddings(nn.Module):
199
+
200
+ def __init__(self, config: Videollama3VisionEncoderConfig):
201
+ super().__init__()
202
+ self.config = config
203
+ self.embed_dim = config.hidden_size
204
+ self.patch_size = config.patch_size
205
+
206
+ self.patch_embedding = nn.Conv2d(
207
+ in_channels=config.num_channels,
208
+ out_channels=self.embed_dim,
209
+ kernel_size=self.patch_size,
210
+ stride=self.patch_size,
211
+ padding="valid",
212
+ )
213
+
214
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
215
+ hidden_states = hidden_states.view(
216
+ -1, self.config.num_channels, self.patch_size, self.patch_size
217
+ )
218
+ patch_embeds = self.patch_embedding(hidden_states) # shape = [*, width, grid, grid]
219
+ # embeddings = patch_embeds.flatten(2).transpose(1, 2)
220
+ embeddings = patch_embeds.view(-1, self.embed_dim)
221
+
222
+ return embeddings
223
+
224
+
225
+ class VisionAttention(nn.Module):
226
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
227
+
228
+ # Copied from transformers.models.clip.modeling_clip.CLIPAttention.__init__
229
+ def __init__(self, config):
230
+ super().__init__()
231
+ self.config = config
232
+ self.embed_dim = config.hidden_size
233
+ self.num_heads = config.num_attention_heads
234
+ self.head_dim = self.embed_dim // self.num_heads
235
+ if self.head_dim * self.num_heads != self.embed_dim:
236
+ raise ValueError(
237
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
238
+ f" {self.num_heads})."
239
+ )
240
+ self.scale = self.head_dim**-0.5
241
+ self.dropout = config.attention_dropout
242
+
243
+ self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
244
+ self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
245
+ self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
246
+ self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
247
+
248
+ def forward(
249
+ self,
250
+ hidden_states: torch.Tensor,
251
+ cu_seqlens: torch.Tensor,
252
+ rotary_pos_emb: torch.Tensor = None,
253
+ ) -> torch.Tensor:
254
+ """Input shape: Time x Channel"""
255
+
256
+ q_len, _ = hidden_states.size()
257
+
258
+ query_states = self.q_proj(hidden_states)
259
+ key_states = self.k_proj(hidden_states)
260
+ value_states = self.v_proj(hidden_states)
261
+
262
+ query_states = query_states.view(q_len, self.num_heads, self.head_dim)
263
+ key_states = key_states.view(q_len, self.num_heads, self.head_dim)
264
+ value_states = value_states.view(q_len, self.num_heads, self.head_dim)
265
+
266
+ query_states = apply_rotary_pos_emb_vision(query_states.unsqueeze(0), rotary_pos_emb).squeeze(0)
267
+ key_states = apply_rotary_pos_emb_vision(key_states.unsqueeze(0), rotary_pos_emb).squeeze(0)
268
+
269
+ attention_mask = torch.zeros([1, q_len, q_len], device=query_states.device, dtype=torch.bool)
270
+ for i in range(1, len(cu_seqlens)):
271
+ attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = True
272
+
273
+ query_states = query_states.transpose(0, 1)
274
+ key_states = key_states.transpose(0, 1)
275
+ value_states = value_states.transpose(0, 1)
276
+
277
+ attn_weights = torch.matmul(query_states, key_states.transpose(1, 2)) / math.sqrt(self.head_dim)
278
+ attn_weights = attn_weights + attention_mask
279
+
280
+ # upcast attention to fp32
281
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
282
+ attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
283
+ attn_output = torch.matmul(attn_weights, value_states)
284
+
285
+ attn_output = attn_output.transpose(0, 1)
286
+ attn_output = attn_output.reshape(q_len, -1)
287
+ attn_output = self.out_proj(attn_output)
288
+
289
+ return attn_output
290
+
291
+
292
+ class VisionFlashAttention2(VisionAttention):
293
+
294
+ def __init__(self, *args, **kwargs):
295
+ super().__init__(*args, **kwargs)
296
+
297
+ # Adapted from transformers.models.llama.modeling_llama.LlamaFlashAttention2.forward
298
+ def forward(
299
+ self,
300
+ hidden_states: torch.Tensor,
301
+ cu_seqlens: torch.Tensor,
302
+ rotary_pos_emb: torch.Tensor = None,
303
+ ) -> torch.Tensor:
304
+ q_len, _ = hidden_states.size()
305
+
306
+ query_states = self.q_proj(hidden_states)
307
+ key_states = self.k_proj(hidden_states)
308
+ value_states = self.v_proj(hidden_states)
309
+
310
+ # Flash attention requires the input to have the shape
311
+ # batch_size x seq_length x head_dim x hidden_dim
312
+ # therefore we just need to keep the original shape
313
+ query_states = query_states.view(1, q_len, self.num_heads, self.head_dim).transpose(1, 2)
314
+ key_states = key_states.view(1, q_len, self.num_heads, self.head_dim).transpose(1, 2)
315
+ value_states = value_states.view(q_len, self.num_heads, self.head_dim)
316
+ # query_states = apply_rotary_pos_emb_vision(query_states.unsqueeze(0), rotary_pos_emb).squeeze(0)
317
+ # key_states = apply_rotary_pos_emb_vision(key_states.unsqueeze(0), rotary_pos_emb).squeeze(0)
318
+ query_states, key_states = apply_rotary_pos_emb_vision(
319
+ query_states,
320
+ key_states,
321
+ rotary_pos_emb.cos().unsqueeze(0).repeat(1, 1, 2),
322
+ rotary_pos_emb.sin().unsqueeze(0).repeat(1, 1, 2),
323
+ )
324
+ query_states = query_states.transpose(1, 2).squeeze(0)
325
+ key_states = key_states.transpose(1, 2).squeeze(0)
326
+
327
+ max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
328
+ attn_output = flash_attn_varlen_func(query_states, key_states, value_states, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen).reshape(
329
+ q_len, -1
330
+ )
331
+ attn_output = self.out_proj(attn_output)
332
+
333
+ return attn_output
334
+
335
+
336
+ class VisionSdpaAttention(VisionAttention):
337
+
338
+ def forward(
339
+ self,
340
+ hidden_states: torch.Tensor,
341
+ cu_seqlens: torch.Tensor,
342
+ rotary_pos_emb: torch.Tensor = None,
343
+ ) -> torch.Tensor:
344
+ seq_length = hidden_states.shape[0]
345
+ query_states = self.q_proj(hidden_states)
346
+ key_states = self.k_proj(hidden_states)
347
+ value_states = self.v_proj(hidden_states)
348
+
349
+ query_states = query_states.view(seq_length, self.num_heads, self.head_dim)
350
+ key_states = key_states.view(seq_length, self.num_heads, self.head_dim)
351
+ value_states = value_states.view(seq_length, self.num_heads, self.head_dim)
352
+
353
+ query_states = apply_rotary_pos_emb_vision(query_states.unsqueeze(0), rotary_pos_emb).squeeze(0)
354
+ key_states = apply_rotary_pos_emb_vision(key_states.unsqueeze(0), rotary_pos_emb).squeeze(0)
355
+
356
+ attention_mask = torch.zeros([1, seq_length, seq_length], device=query_states.device, dtype=torch.bool)
357
+ for i in range(1, len(cu_seqlens)):
358
+ attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = True
359
+
360
+ query_states = query_states.transpose(0, 1)
361
+ key_states = key_states.transpose(0, 1)
362
+ value_states = value_states.transpose(0, 1)
363
+ attn_output = F.scaled_dot_product_attention(query_states, key_states, value_states, attention_mask, dropout_p=0.0)
364
+ attn_output = attn_output.transpose(0, 1)
365
+ attn_output = attn_output.reshape(seq_length, -1)
366
+ attn_output = self.out_proj(attn_output)
367
+ return attn_output
368
+
369
+
370
+ VISION_ATTENTION_CLASSES = {
371
+ "eager": VisionAttention,
372
+ "flash_attention_2": VisionFlashAttention2,
373
+ "sdpa": VisionSdpaAttention,
374
+ }
375
+
376
+
377
+ # Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->Videollama3
378
+ class Videollama3VisionMLP(nn.Module):
379
+
380
+ def __init__(self, config):
381
+ super().__init__()
382
+ self.config = config
383
+ self.activation_fn = ACT2FN[config.hidden_act]
384
+ self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
385
+ self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
386
+
387
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
388
+ hidden_states = self.fc1(hidden_states)
389
+ hidden_states = self.activation_fn(hidden_states)
390
+ hidden_states = self.fc2(hidden_states)
391
+ return hidden_states
392
+
393
+
394
+ class Videollama3VisionEncoderLayer(nn.Module):
395
+
396
+ def __init__(self, config: Videollama3VisionEncoderConfig):
397
+ super().__init__()
398
+ self.embed_dim = config.hidden_size
399
+ self.self_attn = VISION_ATTENTION_CLASSES[config._attn_implementation](config=config)
400
+ self.layer_norm1 = LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
401
+ self.mlp = Videollama3VisionMLP(config)
402
+ self.layer_norm2 = LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
403
+
404
+ # Ignore copy
405
+ def forward(self, hidden_states, cu_seqlens, rotary_pos_emb) -> torch.Tensor:
406
+ hidden_states = hidden_states + self.self_attn(
407
+ self.layer_norm1(hidden_states), cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb
408
+ )
409
+ hidden_states = hidden_states + self.mlp(self.layer_norm2(hidden_states))
410
+ return hidden_states
411
+
412
+
413
+ class Videollama3VisionTransformerEncoder(nn.Module):
414
+
415
+ def __init__(self, config: Videollama3VisionEncoderConfig):
416
+ super().__init__()
417
+ self.config = config
418
+ head_dim = config.hidden_size // config.num_attention_heads
419
+ self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2)
420
+ self.layers = nn.ModuleList([Videollama3VisionEncoderLayer(config) for _ in range(config.num_hidden_layers)])
421
+ self.gradient_checkpointing = False
422
+
423
+ def rot_pos_emb(self, grid_sizes, merge_sizes):
424
+ pos_ids = []
425
+ for (t, h, w), merge_size in zip(grid_sizes, merge_sizes):
426
+ hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
427
+ hpos_ids = hpos_ids.reshape(
428
+ h // merge_size,
429
+ merge_size,
430
+ w // merge_size,
431
+ merge_size,
432
+ )
433
+ hpos_ids = hpos_ids.permute(0, 2, 1, 3)
434
+ hpos_ids = hpos_ids.flatten()
435
+
436
+ wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
437
+ wpos_ids = wpos_ids.reshape(
438
+ h // merge_size,
439
+ merge_size,
440
+ w // merge_size,
441
+ merge_size,
442
+ )
443
+ wpos_ids = wpos_ids.permute(0, 2, 1, 3)
444
+ wpos_ids = wpos_ids.flatten()
445
+ pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
446
+
447
+ pos_ids = torch.cat(pos_ids, dim=0)
448
+ max_grid_size = grid_sizes[:, 1:].max()
449
+ rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
450
+ rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
451
+
452
+ return rotary_pos_emb
453
+
454
+ def forward(self, hidden_states, grid_sizes, merge_sizes) -> torch.Tensor:
455
+ rotary_pos_emb = self.rot_pos_emb(grid_sizes, merge_sizes)
456
+
457
+ cu_seqlens = torch.repeat_interleave(grid_sizes[:, 1] * grid_sizes[:, 2], grid_sizes[:, 0]).cumsum(dim=0, dtype=torch.int32)
458
+ cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
459
+
460
+ for blk in self.layers:
461
+ if self.gradient_checkpointing and self.training:
462
+ hidden_states = self._gradient_checkpointing_func(
463
+ blk.__call__,
464
+ hidden_states,
465
+ cu_seqlens,
466
+ rotary_pos_emb
467
+ )
468
+ else:
469
+ hidden_states = blk(hidden_states, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb)
470
+
471
+ return hidden_states
472
+
473
+
474
+ class Videollama3VisionEncoderModel(PreTrainedModel):
475
+
476
+ config_class = Videollama3VisionEncoderConfig
477
+ base_model_prefix = "videollama3"
478
+ main_input_name = "pixel_values"
479
+ supports_gradient_checkpointing = True
480
+ _no_split_modules = [
481
+ "Videollama3VisionEncoderLayer",
482
+ "Videollama3VisionEmbeddings",
483
+ ]
484
+ _supports_flash_attn_2 = True
485
+ _supports_sdpa = True
486
+
487
+ def __init__(self, config: Videollama3VisionEncoderConfig):
488
+ super().__init__(config=config)
489
+ embed_dim = config.hidden_size
490
+
491
+ self.embeddings = Videollama3VisionEmbeddings(config)
492
+ self.encoder = Videollama3VisionTransformerEncoder(config)
493
+ self.post_layernorm = LayerNorm(embed_dim, eps=config.layer_norm_eps)
494
+
495
+ self.post_init()
496
+
497
+ def forward(self, pixel_values, grid_sizes, merge_sizes=None):
498
+ hidden_states = self.embeddings(pixel_values)
499
+ hidden_states = self.encoder(hidden_states, grid_sizes, merge_sizes)
500
+ hidden_states = self.post_layernorm(hidden_states)
501
+ hidden_states_raw = hidden_states.clone()
502
+
503
+ hidden_states_chunks = hidden_states.split(grid_sizes.prod(dim=1).tolist(), dim=0)
504
+ outputs = []
505
+
506
+ for hidden_states, grid_size, merge_size in zip(hidden_states_chunks, grid_sizes, merge_sizes):
507
+ # NOTE: previous implementation, which supports downsampling with any factor
508
+ c = hidden_states.shape[-1]
509
+ hidden_states = hidden_states.view(
510
+ grid_size[0], grid_size[1] // merge_size, grid_size[2] // merge_size, merge_size, merge_size, c
511
+ ).permute(0, 1, 3, 2, 4, 5)
512
+ hidden_states = hidden_states.reshape(
513
+ grid_size[0], grid_size[1], grid_size[2], c
514
+ ).permute(0, 3, 1, 2)
515
+ hidden_states = torch.nn.functional.interpolate(
516
+ hidden_states,
517
+ size=(grid_size[1] // merge_size, grid_size[2] // merge_size),
518
+ mode='bilinear'
519
+ )
520
+ hidden_states = hidden_states.permute(0, 2, 3, 1).view(-1, c)
521
+
522
+ # NOTE: simplified implementation, which only supports downsampling with integer factor
523
+ # NOTE: this implementation is mathematically equivalent to the previous one when merge_size is 1 or 2 but may cause slightly different results
524
+ # hidden_states = hidden_states.view(-1, merge_size * merge_size, hidden_states.size(-1))
525
+ # hidden_states = hidden_states.mean(dim=1)
526
+
527
+ outputs.append(hidden_states)
528
+
529
+ return torch.cat(outputs, dim=0), hidden_states_raw
530
+
531
+ def _init_weights(self, module):
532
+ """Initialize the weights"""
533
+ if isinstance(module, nn.Embedding):
534
+ default_flax_embed_init(module.weight)
535
+ elif isinstance(module, VisionAttention):
536
+ nn.init.xavier_uniform_(module.q_proj.weight)
537
+ nn.init.xavier_uniform_(module.k_proj.weight)
538
+ nn.init.xavier_uniform_(module.v_proj.weight)
539
+ nn.init.xavier_uniform_(module.out_proj.weight)
540
+ nn.init.zeros_(module.q_proj.bias)
541
+ nn.init.zeros_(module.k_proj.bias)
542
+ nn.init.zeros_(module.v_proj.bias)
543
+ nn.init.zeros_(module.out_proj.bias)
544
+ elif isinstance(module, Videollama3VisionMLP):
545
+ nn.init.xavier_uniform_(module.fc1.weight)
546
+ nn.init.xavier_uniform_(module.fc2.weight)
547
+ nn.init.normal_(module.fc1.bias, std=1e-6)
548
+ nn.init.normal_(module.fc2.bias, std=1e-6)
549
+ elif isinstance(module, (nn.Linear, nn.Conv2d)):
550
+ lecun_normal_(module.weight)
551
+ if module.bias is not None:
552
+ nn.init.zeros_(module.bias)
553
+ elif isinstance(module, LayerNorm):
554
+ module.bias.data.zero_()
555
+ module.weight.data.fill_(1.0)
rynnec/rynnec_trainer.py ADDED
@@ -0,0 +1,496 @@
1
+ # Adopted from: https://github.com/haotian-liu/LLaVA/blob/main/llava/train/llava_trainer.py
2
+ import os
3
+ import logging
4
+ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ from torch.utils.data import Sampler
9
+
10
+ from transformers import Trainer
11
+ from transformers.trainer import (
12
+ is_sagemaker_mp_enabled,
13
+ get_parameter_names,
14
+ has_length,
15
+ ALL_LAYERNORM_LAYERS,
16
+ logger,
17
+ TRAINER_STATE_NAME,
18
+ )
19
+ from transformers.utils import (
20
+ is_sagemaker_mp_enabled,
21
+ logging,
22
+ )
23
+
24
+ def maybe_zero_3(param, ignore_status=False, name=None):
25
+ from deepspeed import zero
26
+ from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
27
+ if hasattr(param, "ds_id"):
28
+ if param.ds_status == ZeroParamStatus.NOT_AVAILABLE:
29
+ if not ignore_status:
30
+ logging.warning(f"{name}: param.ds_status != ZeroParamStatus.NOT_AVAILABLE: {param.ds_status}")
31
+ with zero.GatheredParameters([param]):
32
+ param = param.data.detach().cpu().clone()
33
+ else:
34
+ param = param.detach().cpu().clone()
35
+ return param
36
+
37
+
38
+ def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match):
39
+ to_return = {k: t for k, t in named_params if any(key_match in k for key_match in keys_to_match)}
40
+ to_return = {k: maybe_zero_3(v, ignore_status=True, name=k).cpu() for k, v in to_return.items()}
41
+ return to_return
42
+
43
+
44
+ # Borrowed from peft.utils.get_peft_model_state_dict
45
+ def get_peft_state_maybe_zero_3(named_params, bias):
46
+ if bias == "none":
47
+ to_return = {k: t for k, t in named_params if "lora_" in k}
48
+ elif bias == "all":
49
+ to_return = {k: t for k, t in named_params if "lora_" in k or "bias" in k}
50
+ elif bias == "lora_only":
51
+ to_return = {}
52
+ maybe_lora_bias = {}
53
+ lora_bias_names = set()
54
+ for k, t in named_params:
55
+ if "lora_" in k:
56
+ to_return[k] = t
57
+ bias_name = k.split("lora_")[0] + "bias"
58
+ lora_bias_names.add(bias_name)
59
+ elif "bias" in k:
60
+ maybe_lora_bias[k] = t
61
+ for k, t in maybe_lora_bias:
62
+ if bias_name in lora_bias_names:
63
+ to_return[bias_name] = t
64
+ else:
65
+ raise NotImplementedError
66
+ to_return = {k: maybe_zero_3(v, ignore_status=True) for k, v in to_return.items()}
67
+ return to_return
68
+
69
+
70
+ def get_peft_state_non_lora_maybe_zero_3(named_params, require_grad_only=True):
71
+ to_return = {k: t for k, t in named_params if "lora_" not in k}
72
+ if require_grad_only:
73
+ to_return = {k: t for k, t in to_return.items() if t.requires_grad}
74
+ to_return = {k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()}
75
+ return to_return
76
+
77
+
78
+ def find_all_linear_names(model):
79
+ cls = torch.nn.Linear
80
+ lora_module_names = set()
81
+ multimodal_keywords = ['mm_projector', 'vision_encoder', 'vision_resampler', 'text_hidden_fcs', 'region_encoder', 'grounding_encoder']
82
+ for name, module in model.named_modules():
83
+ if any(mm_keyword in name for mm_keyword in multimodal_keywords):
84
+ continue
85
+ if isinstance(module, cls):
86
+ if 'lm_head' in name:
87
+ continue
88
+ lora_module_names.add(name)
89
+
90
+ return list(lora_module_names)
91
+
92
+ def safe_save_model_for_hf_trainer(trainer: Trainer,
93
+ output_dir: str):
94
+ """Collects the state dict and dump to disk."""
95
+
96
+ if getattr(trainer.args, "is_alignment", False):
97
+ # Only save Adapter
98
+ keys_to_match = ['mm_projector']
99
+
100
+ weight_to_save = get_mm_adapter_state_maybe_zero_3(trainer.model.named_parameters(), keys_to_match)
101
+ trainer.model.config.save_pretrained(output_dir)
102
+
103
+ current_folder = output_dir.split('/')[-1]
104
+ parent_folder = os.path.dirname(output_dir)
105
+ # if trainer.args.local_rank == 0 or trainer.args.local_rank == -1:
106
+ if torch.distributed.get_rank() == 0:
107
+ if current_folder.startswith('checkpoint-'):
108
+ mm_projector_folder = os.path.join(parent_folder, "mm_projector")
109
+ os.makedirs(mm_projector_folder, exist_ok=True)
110
+ torch.save(weight_to_save, os.path.join(mm_projector_folder, f'{current_folder}.bin'))
111
+ else:
112
+ torch.save(weight_to_save, os.path.join(output_dir, f'mm_projector.bin'))
113
+ return
114
+
115
+ if trainer.deepspeed:
116
+ torch.cuda.synchronize()
117
+ trainer.save_model(output_dir)
118
+ return
119
+
120
+ state_dict = trainer.model.state_dict()
121
+ if trainer.args.should_save:
122
+ cpu_state_dict = {
123
+ key: value.cpu()
124
+ for key, value in state_dict.items()
125
+ }
126
+ del state_dict
127
+ trainer._save(output_dir, state_dict=cpu_state_dict) # noqa
128
+
129
+
130
+ def split_to_even_chunks(indices, lengths, num_chunks):
131
+ """
132
+ Split a list of indices into `chunks` chunks of roughly equal lengths.
133
+ """
134
+
135
+ if len(indices) % num_chunks != 0:
136
+ return [indices[i::num_chunks] for i in range(num_chunks)]
137
+
138
+ num_indices_per_chunk = len(indices) // num_chunks
139
+
140
+ chunks = [[] for _ in range(num_chunks)]
141
+ chunks_lengths = [0 for _ in range(num_chunks)]
142
+ for index in indices:
143
+ shortest_chunk = chunks_lengths.index(min(chunks_lengths))
144
+ chunks[shortest_chunk].append(index)
145
+ chunks_lengths[shortest_chunk] += lengths[index]
146
+ if len(chunks[shortest_chunk]) == num_indices_per_chunk:
147
+ chunks_lengths[shortest_chunk] = float("inf")
148
+
149
+ return chunks
150
+
151
+
152
+ def get_modality_length_grouped_indices(lengths, batch_size, world_size, generator=None):
153
+ # We need to use torch for the random part as a distributed sampler will set the random seed for torch.
154
+ assert all(l != 0 for l in lengths), "Should not have zero length."
155
+ if all(l > 0 for l in lengths) or all(l < 0 for l in lengths):
156
+ # all samples are in the same modality
157
+ return get_length_grouped_indices(lengths, batch_size, world_size, generator=generator)
158
+ mm_indices, mm_lengths = zip(*[(i, l) for i, l in enumerate(lengths) if l > 0])
159
+ lang_indices, lang_lengths = zip(*[(i, -l) for i, l in enumerate(lengths) if l < 0])
160
+
161
+ mm_shuffle = [mm_indices[i] for i in get_length_grouped_indices(mm_lengths, batch_size, world_size, generator=None)]
162
+ lang_shuffle = [lang_indices[i] for i in get_length_grouped_indices(lang_lengths, batch_size, world_size, generator=None)]
163
+ megabatch_size = world_size * batch_size
164
+ mm_megabatches = [mm_shuffle[i : i + megabatch_size] for i in range(0, len(mm_shuffle), megabatch_size)]
165
+ lang_megabatches = [lang_shuffle[i : i + megabatch_size] for i in range(0, len(lang_shuffle), megabatch_size)]
166
+
167
+ # last_mm = mm_megabatches[-1]
168
+ # last_lang = lang_megabatches[-1]
169
+ # additional_batch = last_mm + last_lang
170
+ megabatches = mm_megabatches[:-1] + lang_megabatches[:-1]
171
+ megabatch_indices = torch.randperm(len(megabatches), generator=generator)
172
+ megabatches = [megabatches[i] for i in megabatch_indices]
173
+
174
+ # if len(additional_batch) > 0:
175
+ # megabatches.append(sorted(additional_batch))
176
+
177
+ return [i for megabatch in megabatches for i in megabatch]
178
+
179
+
180
+ def get_length_grouped_indices(lengths, batch_size, world_size, generator=None, merge=True):
181
+ # We need to use torch for the random part as a distributed sampler will set the random seed for torch.
182
+ indices = torch.randperm(len(lengths), generator=generator)
183
+ megabatch_size = world_size * batch_size
184
+ megabatches = [indices[i : i + megabatch_size].tolist() for i in range(0, len(lengths), megabatch_size)]
185
+ megabatches = [sorted(megabatch, key=lambda i: lengths[i], reverse=True) for megabatch in megabatches]
186
+ megabatches = [split_to_even_chunks(megabatch, lengths, world_size) for megabatch in megabatches]
187
+
188
+ return [i for megabatch in megabatches for batch in megabatch for i in batch]
189
+
190
+
191
+ class LengthGroupedSampler(Sampler):
192
+ r"""
193
+ Sampler that samples indices in a way that groups together features of the dataset of roughly the same length while
194
+ keeping a bit of randomness.
195
+ """
196
+
197
+ def __init__(
198
+ self,
199
+ batch_size: int,
200
+ world_size: int,
201
+ lengths: Optional[List[int]] = None,
202
+ generator=None,
203
+ group_by_modality: bool = False,
204
+ ):
205
+ if lengths is None:
206
+ raise ValueError("Lengths must be provided.")
207
+
208
+ self.batch_size = batch_size
209
+ self.world_size = world_size
210
+ self.lengths = lengths
211
+ self.generator = generator
212
+ self.group_by_modality = group_by_modality
213
+
214
+ def __len__(self):
215
+ return len(self.lengths)
216
+
217
+ def __iter__(self):
218
+ if self.group_by_modality:
219
+ indices = get_modality_length_grouped_indices(self.lengths, self.batch_size, self.world_size, generator=self.generator)
220
+ else:
221
+ indices = get_length_grouped_indices(self.lengths, self.batch_size, self.world_size, generator=self.generator)
222
+ return iter(indices)
223
+
224
+
225
+ class RynnECTrainer(Trainer):
226
+
227
+ def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
228
+ print('_get_train_sampler')
229
+ print('world size: ', self.args.world_size * self.args.gradient_accumulation_steps)
230
+ if self.train_dataset is None or not has_length(self.train_dataset):
231
+ return None
232
+ print('group_by_modality_length...')
233
+ if self.args.group_by_modality_length:
234
+ lengths = self.train_dataset.modality_lengths
235
+ return LengthGroupedSampler(
236
+ self.args.train_batch_size,
237
+ world_size=self.args.world_size * self.args.gradient_accumulation_steps,
238
+ lengths=lengths,
239
+ group_by_modality=True,
240
+ )
241
+ else:
242
+ return super()._get_train_sampler()
243
+
244
+ def update_history_loss_dict(self,outputs):
245
+ if not hasattr(self,'history_loss_dict'):
246
+ self.history_loss_dict = {}
247
+ for name, value in outputs.items():
248
+ if 'loss' in name and name != 'loss':
249
+ if name not in self.history_loss_dict:
250
+ self.history_loss_dict[name] = value.item()
251
+ else:
252
+ if value != 0:
253
+ self.history_loss_dict[name] = value.item()
254
+ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
255
+ """
256
+ How the loss is computed by Trainer. By default, all models return the loss in the first element.
257
+
258
+ Subclass and override for custom behavior.
259
+ """
260
+ if (self.label_smoother is not None or self.compute_loss_func is not None) and "labels" in inputs:
261
+ labels = inputs.pop("labels")
262
+ else:
263
+ labels = None
264
+ if self.model_accepts_loss_kwargs:
265
+ loss_kwargs = {}
266
+ if num_items_in_batch is not None:
267
+ loss_kwargs["num_items_in_batch"] = num_items_in_batch
268
+ inputs = {**inputs, **loss_kwargs}
269
+ outputs = model(**inputs)
270
+ # Save past state if it exists
271
+ # TODO: this needs to be fixed and made cleaner later.
272
+ if self.args.past_index >= 0:
273
+ self._past = outputs[self.args.past_index]
274
+
275
+ if labels is not None:
276
+ unwrapped_model = self.accelerator.unwrap_model(model)
277
+ if _is_peft_model(unwrapped_model):
278
+ model_name = unwrapped_model.base_model.model._get_name()
279
+ else:
280
+ model_name = unwrapped_model._get_name()
281
+ # User-defined compute_loss function
282
+ if self.compute_loss_func is not None:
283
+ loss = self.compute_loss_func(outputs, labels, num_items_in_batch=num_items_in_batch)
284
+ elif model_name in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values():
285
+ loss = self.label_smoother(outputs, labels, shift_labels=True)
286
+ else:
287
+ loss = self.label_smoother(outputs, labels)
288
+ else:
289
+ if isinstance(outputs, dict) and "loss" not in outputs:
290
+ raise ValueError(
291
+ "The model did not return a loss from the inputs, only the following keys: "
292
+ f"{','.join(outputs.keys())}. For reference, the inputs it received are {','.join(inputs.keys())}."
293
+ )
294
+ # We don't use .loss here since the model may return tuples instead of ModelOutput.
295
+ loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]
296
+ if isinstance(outputs, dict) and 'mask_bce_loss' in outputs:
297
+ loss_dict = {}
298
+ for name,value in outputs.items():
299
+ if 'loss' in name and name != 'loss':
300
+ loss_value = value.item()
301
+ if loss_value == 0 and hasattr(self,'history_loss_dict'):
302
+ loss_value = self.history_loss_dict[name]
303
+ loss_dict[name] = loss_value
304
+ self.update_history_loss_dict(outputs)
305
+ self.log(loss_dict)
306
+
307
+ if self.args.average_tokens_across_devices and self.model_accepts_loss_kwargs:
308
+ loss *= self.accelerator.num_processes
309
+
310
+ return (loss, outputs) if return_outputs else loss
311
+
312
+
313
+ def create_optimizer(self):
314
+ """
315
+ Setup the optimizer.
316
+
317
+ We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
318
+ Trainer's init through `optimizers`, or subclass and override this method in a subclass.
319
+ """
320
+ if is_sagemaker_mp_enabled():
321
+ return super().create_optimizer()
322
+
323
+ opt_model = self.model
324
+
325
+ if self.optimizer is None:
326
+ optimized_parameters = [(n, p) for n, p in opt_model.named_parameters() if p.requires_grad]
327
+ optimizer_grouped_parameters = []
328
+
329
+ decay_parameters = get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS)
330
+ decay_parameters = [name for name in decay_parameters if "bias" not in name]
331
+
332
+ if self.args.llm_lr is not None:
333
+ lm_parameters = [
334
+ name for name, _ in optimized_parameters if "vision_encoder" not in name and "mm_projector" not in name and "region_encoder" not in name and "grounding_encoder" not in name
335
+ ]
336
+ decay_lm_parameters = [name for name in lm_parameters if name in decay_parameters]
337
+ nodecay_lm_parameters = [name for name in lm_parameters if name not in decay_parameters]
338
+ optimizer_grouped_parameters.extend([
339
+ {
340
+ "params": [p for n, p in optimized_parameters if n in decay_lm_parameters],
341
+ "weight_decay": self.args.weight_decay,
342
+ "lr": self.args.llm_lr,
343
+ },
344
+ {
345
+ "params": [p for n, p in optimized_parameters if n in nodecay_lm_parameters],
346
+ "weight_decay": 0.0,
347
+ "lr": self.args.llm_lr,
348
+ }
349
+ ])
350
+
351
+ if self.args.mm_projector_lr is not None:
352
+ projector_parameters = [name for name, _ in optimized_parameters if "mm_projector" in name]
353
+ decay_projector_parameters = [name for name in projector_parameters if name in decay_parameters]
354
+ nodecay_projector_parameters = [name for name in projector_parameters if name not in decay_parameters]
355
+ optimizer_grouped_parameters.extend([
356
+ {
357
+ "params": [p for n, p in optimized_parameters if n in decay_projector_parameters],
358
+ "weight_decay": self.args.weight_decay,
359
+ "lr": self.args.mm_projector_lr,
360
+ },
361
+ {
362
+ "params": [p for n, p in optimized_parameters if n in nodecay_projector_parameters],
363
+ "weight_decay": 0.0,
364
+ "lr": self.args.mm_projector_lr,
365
+ }
366
+ ])
367
+
368
+ if self.args.vision_encoder_lr is not None:
369
+ vision_encoder_parameters = [name for name, _ in optimized_parameters if "vision_encoder" in name]
370
+ decay_vision_encoder_parameters = [name for name in vision_encoder_parameters if name in decay_parameters]
371
+ nodecay_vision_encoder_parameters = [name for name in vision_encoder_parameters if name not in decay_parameters]
372
+ optimizer_grouped_parameters.extend([
373
+ {
374
+ "params": [p for n, p in optimized_parameters if n in decay_vision_encoder_parameters],
375
+ "weight_decay": self.args.weight_decay,
376
+ "lr": self.args.vision_encoder_lr,
377
+ },
378
+ {
379
+ "params": [p for n, p in optimized_parameters if n in nodecay_vision_encoder_parameters],
380
+ "weight_decay": 0.0,
381
+ "lr": self.args.vision_encoder_lr,
382
+ }
383
+ ])
384
+
385
+ if self.args.region_encoder_lr is not None:
386
+ projector_parameters = [name for name, _ in optimized_parameters if "region_encoder" in name]
387
+ decay_projector_parameters = [name for name in projector_parameters if name in decay_parameters]
388
+ nodecay_projector_parameters = [name for name in projector_parameters if name not in decay_parameters]
389
+ optimizer_grouped_parameters.extend([
390
+ {
391
+ "params": [p for n, p in optimized_parameters if n in decay_projector_parameters],
392
+ "weight_decay": self.args.weight_decay,
393
+ "lr": self.args.region_encoder_lr,
394
+ },
395
+ {
396
+ "params": [p for n, p in optimized_parameters if n in nodecay_projector_parameters],
397
+ "weight_decay": 0.0,
398
+ "lr": self.args.region_encoder_lr,
399
+ }
400
+ ])
401
+ if self.args.sam_decoder_lr is not None:
402
+ projector_parameters = [name for name, _ in optimized_parameters if "grounding_encoder" in name and "image_encoder" not in name]
403
+ decay_projector_parameters = [name for name in projector_parameters if name in decay_parameters]
404
+ nodecay_projector_parameters = [name for name in projector_parameters if name not in decay_parameters]
405
+ optimizer_grouped_parameters.extend([
406
+ {
407
+ "params": [p for n, p in optimized_parameters if n in decay_projector_parameters],
408
+ "weight_decay": self.args.weight_decay,
409
+ "lr": self.args.sam_decoder_lr,
410
+ },
411
+ {
412
+ "params": [p for n, p in optimized_parameters if n in nodecay_projector_parameters],
413
+ "weight_decay": 0.0,
414
+ "lr": self.args.sam_decoder_lr,
415
+ }
416
+ ])
417
+
418
+ optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args)
419
+
420
+ self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
421
+ if optimizer_cls.__name__ == "Adam8bit":
422
+ import bitsandbytes
423
+
424
+ manager = bitsandbytes.optim.GlobalOptimManager.get_instance()
425
+
426
+ skipped = 0
427
+ for module in opt_model.modules():
428
+ if isinstance(module, nn.Embedding):
429
+ skipped += sum({p.data_ptr(): p.numel() for p in module.parameters()}.values())
430
+ logger.info(f"skipped {module}: {skipped/2**20}M params")
431
+ manager.register_module_override(module, "weight", {"optim_bits": 32})
432
+ logger.debug(f"bitsandbytes: will optimize {module} in fp32")
433
+ logger.info(f"skipped: {skipped/2**20}M params")
434
+
435
+ return self.optimizer
436
+
437
+ def _save_checkpoint(self, model, trial, metrics=None):
438
+ if getattr(self.args, 'is_alignment', False):
439
+ from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
440
+ checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
441
+
442
+ run_dir = self._get_output_dir(trial=trial)
443
+ output_dir = os.path.join(run_dir, checkpoint_folder)
444
+
445
+ # Only save Adapter
446
+ keys_to_match = ['mm_projector', 'vision_resampler']
447
+
448
+ weight_to_save = get_mm_adapter_state_maybe_zero_3(self.model.named_parameters(), keys_to_match)
449
+
450
+ if self.args.local_rank == 0 or self.args.local_rank == -1:
451
+ self.model.config.save_pretrained(output_dir)
452
+ torch.save(weight_to_save, os.path.join(output_dir, 'mm_projector.bin'))
453
+ # Save optimizer and scheduler
454
+ self._save_optimizer_and_scheduler(output_dir)
455
+ # Save RNG state
456
+ self._save_rng_state(output_dir)
457
+ self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME))
458
+ self.args.distributed_state.wait_for_everyone()
459
+ else:
460
+ # NOTE: support saving a complete LoRA checkpoint during training.
461
+ if self.args.lora_enable:
462
+ from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
463
+ checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
464
+
465
+ run_dir = self._get_output_dir(trial=trial)
466
+ output_dir = os.path.join(run_dir, checkpoint_folder)
467
+
468
+ state_dict = get_peft_state_maybe_zero_3(self.model.named_parameters(), self.args.lora_bias)
469
+ non_lora_state_dict = get_peft_state_non_lora_maybe_zero_3(self.model.named_parameters())
470
+
471
+ # add for qwen2
472
+ if hasattr(self.model, 'base_model') and hasattr(self.model.base_model, 'lm_head'):
473
+ lm_head_weight = self.model.base_model.lm_head.weight.cpu()
474
+ non_lora_state_dict['base_model.lm_head.weight'] = lm_head_weight
475
+ print("add base_model.lm_head.weight")
476
+ else:
477
+ print("The model does not have 'base_model.lm_head.weight' attribute.")
478
+
479
+
480
+ if self.args.local_rank == 0 or self.args.local_rank == -1:
481
+ # save for acquiring `config.json`
482
+ self.model.config.save_pretrained(output_dir)
483
+ # save for acquiring `adapter_config.json`, `adapter_model.bin`
484
+ # self.model.save_pretrained(output_dir, state_dict=state_dict)
485
+ torch.save(non_lora_state_dict, os.path.join(output_dir, 'non_lora_trainables.bin'))
486
+
487
+ # save for acquiring lora adapter parameters & trainer states: `adapter_config.json`, `adapter_model.safetensors`
488
+ super(RynnECTrainer, self)._save_checkpoint(model, trial, metrics)
489
+ else:
490
+ super(RynnECTrainer, self)._save_checkpoint(model, trial, metrics)
491
+
492
+ def _save(self, output_dir: Optional[str] = None, state_dict=None):
493
+ if getattr(self.args, 'is_alignment', False):
494
+ pass
495
+ else:
496
+ super(RynnECTrainer, self)._save(output_dir, state_dict)
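
A quick way to sanity-check the length-grouped sampling utilities added in rynnec/rynnec_trainer.py is to drive `LengthGroupedSampler` directly with toy lengths. This is a minimal sketch, assuming `rynnec` is importable from the repository root; the length values below are hypothetical token counts, not values from any real dataset.

import torch
from rynnec.rynnec_trainer import LengthGroupedSampler

# Hypothetical per-sample token counts.
lengths = [128, 512, 64, 300, 900, 32, 256, 700]
sampler = LengthGroupedSampler(
    batch_size=2,
    world_size=2,
    lengths=lengths,
    generator=torch.Generator().manual_seed(0),
    group_by_modality=False,
)
# Indices come back reordered so that each world_size * batch_size megabatch
# holds samples of comparable length, which reduces padding waste per device.
print(list(iter(sampler)))
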
rynnec/train.py ADDED
@@ -0,0 +1,832 @@
1
+ # Adopted from https://github.com/DAMO-NLP-SG/VideoLLaMA3. Below is the original copyright:
2
+ # Adopted from https://github.com/haotian-liu/LLaVA. Below is the original copyright:
3
+ # Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright:
4
+ # Adopted from tatsu-lab@stanford_alpaca. Below is the original copyright:
5
+ # Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
6
+ #
7
+ # Licensed under the Apache License, Version 2.0 (the "License");
8
+ # you may not use this file except in compliance with the License.
9
+ # You may obtain a copy of the License at
10
+ #
11
+ # http://www.apache.org/licenses/LICENSE-2.0
12
+ #
13
+ # Unless required by applicable law or agreed to in writing, software
14
+ # distributed under the License is distributed on an "AS IS" BASIS,
15
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16
+ # See the License for the specific language governing permissions and
17
+ # limitations under the License.
18
+
19
+ import math
20
+ import copy
21
+ import json
22
+ import os
23
+ import pathlib
24
+ import random
25
+ import re
26
+ import sys
27
+ import warnings
28
+ import traceback
29
+ from packaging import version
30
+ from dataclasses import dataclass, field
31
+ from typing import Dict, List, Optional, Sequence
32
+ import numpy as np
33
+ import pyarrow as pa
34
+
35
+ # torch-related packages
36
+ # NOTE: torch must be imported before transformers. Otherwise, `Segmentation fault (core dumped)` will occur.
37
+ import torch
38
+ import transformers
39
+ from packaging import version
40
+ import datasets
41
+ from datasets import load_dataset, concatenate_datasets
42
+ from torch.utils.data import Dataset
43
+ from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
44
+ from transformers import logging
45
+ # logging.set_verbosity_error()
46
+
47
+ sys.path.append('./')
48
+
49
+ from rynnec.constants import (IGNORE_INDEX, MODAL_INDEX_MAP,
50
+ NUM_FRAMES, DEFAULT_IMAGE_TOKEN, STREAM_MAX_FRAMES,
51
+ STREAM_DOWNSAMPLING, STREAM_FPS, STREAM_IMAGE_SIZE,
52
+ STREAM_START_TOKEN, STREAM_END_TOKEN, REGION_TOKEN, SEG_TOKEN, REGION_TOKEN_REPLACE)
53
+ from rynnec.mm_utils import (load_images, load_video, DirectResize, load_video_from_ids,
54
+ tokenizer_multimodal_token, annToMask, sam_preprocess_batch)
55
+ from rynnec.model import *
56
+ from rynnec.rynnec_trainer import (
57
+ RynnECTrainer, find_all_linear_names, get_peft_state_maybe_zero_3,
58
+ get_peft_state_non_lora_maybe_zero_3, safe_save_model_for_hf_trainer)
59
+
60
+ # NOTE: fast tokenizer warning issue: https://github.com/huggingface/transformers/issues/5486
61
+ os.environ["TOKENIZERS_PARALLELISM"] = "true"
62
+
63
+ local_rank = None
64
+
65
+
66
+ def rank0_print(*args):
67
+ if local_rank == 0:
68
+ print(*args)
69
+
70
+
71
+ def set_seed(seed=42):
72
+ """
73
+ Set the random seed for reproducible results.
74
+
75
+ :param seed: An integer value to be used as the random seed.
76
+ """
77
+ torch.manual_seed(seed)
78
+ torch.cuda.manual_seed(seed)
79
+ torch.cuda.manual_seed_all(seed) # for multi-GPU setups
80
+ torch.backends.cudnn.deterministic = True
81
+ torch.backends.cudnn.benchmark = False
82
+
83
+
84
+ def int_with_none(value):
85
+ if value == 'None':
86
+ return None
87
+ return int(value)
88
+
89
+
90
+ @dataclass
91
+ class ModelArguments:
92
+ # LLM Arguments
93
+ model_type: Optional[str] = field(default="rynnec", metadata={"help": "Model type selected in the list: " + ", ".join('rynnec_qwen2')})
94
+ model_path: Optional[str] = field(default="lmsys/vicuna-7b-v1.5")
95
+ version: Optional[str] = field(default="v1", metadata={"help": "Version of the conversation template."})
96
+ freeze_backbone: bool = field(default=False, metadata={"help": "Whether to freeze the LLM backbone."})
97
+ # Connector Arguments
98
+ mm_projector_type: Optional[str] = field(default='linear')
99
+ pretrain_mm_projector: Optional[str] = field(default=None)
100
+ # Vision tower Arguments
101
+ vision_encoder: Optional[str] = field(default=None)
102
+ mm_vision_select_layer: Optional[int] = field(default=-1)
103
+ mm_vision_select_feature: Optional[str] = field(default="patch")
104
+ mm_attn_implementation: Optional[str] = field(default="flash_attention_2")
105
+ # Token downsampling Arguments
106
+ spatial_merge_size: Optional[int] = field(default=1)
107
+ mm_max_length: Optional[int] = field(default=10240)
108
+ use_token_compression: Optional[bool] = field(default=False)
109
+ mask_decoder_model: Optional[str] = field(default="./checkpoints/sam2_hiera_large.pt")
110
+ load_sam2_weight: Optional[bool] = field(default=False)
111
+ training: Optional[bool] = field(default=True)
112
+ has_mask: Optional[bool] = field(default=True)
113
+
114
+ @dataclass
115
+ class DataArguments:
116
+ # Path Arguments
117
+ data_path: List[str] = field(default=None, metadata={"help": "Path to the training data."})
118
+ # image_folder: Optional[str] = field(default=None)
119
+ # video_folder: Optional[str] = field(default=None)
120
+ data_folder: Optional[str] = field(default=None)
121
+ # Loading Arguments
122
+ is_multimodal: bool = False
123
+ fps: Optional[int] = field(default=None)
124
+ max_frames: Optional[int_with_none] = field(default=None)
125
+ # Preprocess Arguments
126
+ image_aspect_ratio: str = 'square'
127
+ use_batch_flattening: bool = field(default=False, metadata={"help": "Whether to flatten the in-batch sequences of variable lengths."})
128
+ dataset_cache_dir: Optional[str] = field(default=None)
129
+
130
+
131
+ @dataclass
132
+ class TrainingArguments(transformers.TrainingArguments):
133
+ # shut auto processing (_remove_unused_columns) of transformers Trainer
134
+ remove_unused_columns: bool = field(default=False)
135
+
136
+ optim: str = field(default="adamw_torch")
137
+ # Training learning rate Arguments
138
+ vision_encoder_lr: Optional[float] = None
139
+ mm_projector_lr: Optional[float] = None
140
+ llm_lr: Optional[float] = None
141
+ region_encoder_lr: Optional[float] = None
142
+ sam_encoder_lr: Optional[float] = None
143
+ sam_decoder_lr: Optional[float] = None
144
+ # Training Data Arguments
145
+ group_by_modality_length: bool = field(default=False)
146
+ model_max_length: int = field(
147
+ default=512,
148
+ metadata={
149
+ "help":
150
+ "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
151
+ },
152
+ )
153
+ # Lora or Quant Arguments
154
+ double_quant: bool = field(
155
+ default=True,
156
+ metadata={"help": "Compress the quantization statistics through double quantization."}
157
+ )
158
+ quant_type: str = field(
159
+ default="nf4",
160
+ metadata={"help": "Quantization data type to use. Should be one of `fp4` or `nf4`."}
161
+ )
162
+ bits: int = field(
163
+ default=16,
164
+ metadata={"help": "How many bits to use."}
165
+ )
166
+ lora_enable: bool = False
167
+ lora_r: int = 64
168
+ lora_alpha: int = 16
169
+ lora_dropout: float = 0.05
170
+ lora_weight_path: str = ""
171
+ lora_bias: str = "none"
172
+
173
+ use_workload_balancing: bool = field(default=False, metadata={"help": "Whether to use data balancing."})
174
+ loss_reduction_scope: str = field(default="batch", metadata={"help": "Loss reduction scope."})
175
+ context_parallel_size: int = field(default=1, metadata={"help": "Context parallel size."})
176
+ use_liger_kernel: bool = field(default=False, metadata={"help": "Whether to use Liger Kernel."})
177
+
178
+
179
+
180
+ class LazySupervisedDataset(Dataset):
181
+ """Dataset for supervised fine-tuning."""
182
+
183
+ def __init__(self, data_path: str, vlprocessor, data_args: DataArguments):
184
+ super(LazySupervisedDataset, self).__init__()
185
+ data_objs = []
186
+
187
+ try:
188
+ for data in data_path:
189
+ # NOTE: load_dataset can process both json or jsonl files
190
+ if data.endswith(".json") or data.endswith(".jsonl"):
191
+ data_objs.append(load_dataset("json", data_files=data, cache_dir=data_args.dataset_cache_dir)["train"])
192
+ else:
193
+ raise Exception(f"Unsupported file format (<{data}>)!")
194
+ list_data_dict = concatenate_datasets(data_objs)
195
+ except:
196
+ traceback.print_exc()
197
+ # NOTE: compatible with the old version
198
+ list_data_dict = []
199
+ for data in data_path:
200
+ if data.endswith(".json"):
201
+ data = json.load(open(data, "r"))
202
+ for i in data:
203
+ i['id'] = len(list_data_dict)
204
+ list_data_dict.append(i)
205
+ elif data.endswith(".jsonl"):
206
+ with open(data, "r", encoding="utf-8") as fp:
207
+ for line in fp:
208
+ line = line.strip()
209
+ obj = json.loads(line)
210
+ obj["id"] = len(list_data_dict)
211
+ list_data_dict.append(obj)
212
+ else:
213
+ raise Exception(f"Unsupported file format (<{data}>)!!!")
214
+
215
+
216
+ rank0_print("Formatting inputs...Skip in lazy mode")
217
+ self.vlprocessor = vlprocessor
218
+ self.list_data_dict = list_data_dict
219
+ self.data_args = data_args
220
+
221
+ img_size=1024
222
+ self.img_size = img_size
223
+ self.sam_transform = DirectResize(img_size)
224
+
225
+ def __len__(self):
226
+ return len(self.list_data_dict)
227
+
228
+ @property
229
+ def lengths(self):
230
+ length_list = []
231
+ for sample in self.list_data_dict:
232
+ img_tokens = 576 if 'image' in sample else 0
233
+ length_list.append(sum(len(conv['value'].split()) for conv in sample['conversations']) + img_tokens)
234
+ return length_list
235
+
236
+ @property
237
+ def modality_lengths(self):
238
+ length_list = []
239
+ for sample in self.list_data_dict:
240
+ cur_len = sum(len(conv['value'].split()) for conv in sample['conversations'])
241
+ if cur_len==0:
242
+ cur_len = 1
243
+ cur_len = cur_len if 'masks' in sample and sample['masks'] is not None and ('seg' not in sample or sample['seg'] is None) else -cur_len
244
+ length_list.append(cur_len)
245
+ return length_list
246
+
247
+ def _convert_normal(self, data_dict):
248
+ data_folder = self.data_args.data_folder
249
+ conversation = copy.deepcopy(data_dict["conversations"])
250
+
251
+ # data sanity check and repair
252
+ start_idx = 0
253
+ for sentence in conversation:
254
+ if sentence["from"] == "human" or sentence["from"] == "system":
255
+ break
256
+ start_idx += 1
257
+ if start_idx > 0:
258
+ warnings.warn(f"Find {start_idx} non-user sentences at the beginning of the conversation, remove them automatically!")
259
+ conversation = conversation[start_idx:]
260
+ assert len(conversation) > 1, f"Invalid conversation"
261
+
262
+ mask_ids = []
263
+
264
+ if 'image' in data_dict and data_dict['image'] is not None:
265
+ modal = 'image'
266
+ if all(not "<image>" in sentence["value"] for sentence in conversation):
267
+ warnings.warn(f"Image tag not found in the conversation, add it automatically at the beginning!")
268
+ conversation[0]["value"] = "<image>" + conversation[0]["value"]
269
+ image_file = data_dict['image']
270
+ if isinstance(image_file, list):
271
+ image_file = [os.path.join(data_folder, f) for f in image_file]
272
+ else:
273
+ image_file = os.path.join(data_folder, image_file)
274
+ images = load_images(image_file)
275
+
276
+ masks = []
277
+ if 'masks' in data_dict and data_dict['masks'] is not None:
278
+ if 'height' in data_dict:
279
+ h = data_dict['height']
280
+ w = data_dict['width']
281
+ else:
282
+ h = None
283
+ w = None
284
+
285
+ if isinstance(data_dict['masks'], str):
286
+ masks_ = json.load(open(data_dict['masks']))
287
+ else:
288
+ masks_= data_dict['masks']
289
+ image2maskids = []
290
+ mask_idx = 0
291
+ for ann in masks_:
292
+ image2maskids_ = []
293
+ mask = annToMask(ann, h, w)
294
+ masks.append(mask)
295
+ mask_ids.append(0)
296
+ image2maskids_.append(mask_idx)
297
+ mask_idx += 1
298
+ image2maskids.append(image2maskids_)
299
+ masks = np.stack(masks, axis=0)
300
+ masks = torch.from_numpy(masks)
301
+
302
+ seg_flag = False
303
+ for conv in conversation:
304
+ conv['value'] = conv['value'].replace(REGION_TOKEN_REPLACE, f'[{REGION_TOKEN}]')
305
+ if SEG_TOKEN in conv['value']:
306
+ seg_flag = True
307
+ if seg_flag is False:
308
+ image2maskids = []
309
+ else:
310
+ mask_ids = [-10000 for i in range(len(mask_ids))]
311
+ else:
312
+ image2maskids = []
313
+ masks = torch.zeros((1, 336, 336))
314
+ mask_ids.append(-10000)
315
+
316
+ elif 'video' in data_dict and data_dict['video'] is not None:
317
+ modal = 'video'
318
+ if all(not "<video>" in sentence["value"] for sentence in conversation):
319
+ warnings.warn(f"Video tag not found in the conversation, add it automatically at the beginning!")
320
+ conversation[0]["value"] = "<video>" + conversation[0]["value"]
321
+ if 'video_root' in data_dict and data_dict['video_root'] is not None:
322
+ video_root = data_dict['video_root']
323
+ video_file = [os.path.join(video_root,d) for d in data_dict['video']]
324
+ else:
325
+ video_file = data_dict['video']
326
+
327
+ if not isinstance(video_file, list):
328
+ video_file = [video_file]
329
+ if isinstance(video_file, list) and len(video_file) == 1 and ('timestamps' not in data_dict or data_dict['timestamps'] is None):
330
+ video_file = os.path.join(data_folder, video_file[0])
331
+ must_sample_frames = []
332
+ if 'masks' in data_dict and data_dict['masks'] is not None:
333
+ if isinstance(data_dict['masks'], str):
334
+ masks_ = json.load(open(data_dict['masks']))
335
+ else:
336
+ masks_= data_dict['masks']
337
+ for ann in masks_:
338
+ for k in ann.keys():
339
+ must_sample_frames.append(int(k))
340
+ images, timestamps, mask_ids = load_video_from_ids(video_file, fps=self.data_args.fps, max_frames=self.data_args.max_frames, must_sample_frames=must_sample_frames)
341
+ elif isinstance(video_file, list): #images
342
+ images = []
343
+ for vf in video_file:
344
+ images+=load_images(os.path.join(data_folder, vf))
345
+ timestamps = data_dict['timestamps']
346
+
347
+ else:
348
+ raise ValueError(f"Unsupported video format: {video_file}")
349
+ images = [images]
350
+ masks = []
351
+ mask_nums = []
352
+ image2maskids = []
353
+ maskid = 0
354
+
355
+ if 'masks' in data_dict and data_dict['masks'] is not None:
356
+ if 'mask_ids' in data_dict and data_dict['mask_ids'] is not None:
357
+ mask_ids = data_dict["mask_ids"]
358
+ if 'height' in data_dict:
359
+ h = data_dict['height']
360
+ w = data_dict['width']
361
+ else:
362
+ h = None
363
+ w = None
364
+
365
+ if isinstance(data_dict['masks'], str):
366
+ masks_ = json.load(open(data_dict['masks']))
367
+ else:
368
+ masks_= data_dict['masks']
369
+ for ann in masks_:
370
+ image2maskids_ = [None]*len(video_file)
371
+ for k in ann.keys():
372
+ mask = annToMask(ann[k], h, w)
373
+ masks.append(mask)
374
+ image2maskids_[mask_ids[maskid]] = maskid
375
+ maskid+=1
376
+ image2maskids.append(image2maskids_)
377
+
378
+ mask_nums.append(len(ann.keys()))
379
+ masks = np.stack(masks, axis=0)
380
+ masks = torch.from_numpy(masks)
381
+
382
+ conv_i = 0
383
+ region_num = 0
384
+ seg_flag = False
385
+ for idx in range(len(mask_nums)):
386
+ while '<region>' not in conversation[conv_i]['value'] and conv_i<len(conversation)-1:
387
+ conv_i+=1
388
+ conversation[conv_i]['value'] = conversation[conv_i]['value'].replace('<region>', "["+REGION_TOKEN*mask_nums[idx]+"]", 1)
389
+ region_num += mask_nums[idx]
390
+ if '[SEG]' in conversation[conv_i]['value']:
391
+ seg_flag = True
392
+
393
+ if seg_flag is False:
394
+ image2maskids = []
395
+ else:
396
+ mask_ids = [-10000 for i in range(len(mask_ids))]
397
+ # assert region_num == len(masks), f"error in {conversation}"
398
+
399
+ else:
400
+ image2maskids = []
401
+ masks = torch.zeros((1, 336, 336))
402
+ mask_ids.append(-10000)
403
+
404
+ else:
405
+ modal = 'text'
406
+ image2maskids = []
407
+ images = None
408
+ masks = torch.zeros((1, 336, 336))
409
+ sam_size = (336, 336)
410
+ sam_images = torch.zeros(1, 3, self.img_size, self.img_size)
411
+ mask_ids = [-10000]
412
+
413
+ if images is not None and len(images)>0:
414
+ sam_images = []
415
+ sam_size = None
416
+ if modal=='video':
417
+ for image in images[0]:
418
+ sam_image = self.sam_transform.apply_image(np.array(image))
419
+ sam_images.append(sam_image)
420
+ if sam_size is None:
421
+ sam_size = sam_image.shape[:2]
422
+ else:
423
+ for image in images:
424
+ sam_image = self.sam_transform.apply_image(np.array(image))
425
+ sam_images.append(sam_image)
426
+ if sam_size is None:
427
+ sam_size = sam_image.shape[:2]
428
+ sam_images = np.array(sam_images)
429
+ sam_images = torch.from_numpy(sam_images).permute(0, 3, 1, 2).contiguous()
430
+ sam_images = sam_preprocess_batch(sam_images)
431
+
432
+ messages = []
433
+ for conv in conversation:
434
+ if conv["from"] == "human":
435
+ # replace video tag to image tag for unified processing
436
+ # conv["value"] = conv["value"].replace("<video>", "<image>" * len(images))
437
+ chunks = conv["value"].split("<image>" if modal == 'image' else "<video>")
438
+ messages.append({
439
+ "role": "user",
440
+ "content": []
441
+ })
442
+
443
+ for chunk_idx in range(1, 2 * len(chunks)):
444
+ if chunk_idx % 2 == 1:
445
+ chunk = chunks[chunk_idx // 2].strip()
446
+ messages[-1]["content"].append({"type": "text", "text": chunk}) if chunk else None
447
+ else:
448
+ if modal == 'image':
449
+ messages[-1]["content"].append({"type": "image"})
450
+ elif modal == 'video':
451
+ messages[-1]["content"].append({"type": "video", "num_frames": len(images[0]), "time": timestamps})
452
+ else:
453
+ messages.append({
454
+ "role": "assistant",
455
+ "content": conv['value']
456
+ })
457
+
458
+ # TODO: dynamic downsampling
459
+ # image_downsampling = self.data_args.spatial_merge_size
460
+ image_downsampling = 2 if modal == "video" else 1
461
+
462
+ return modal, images, messages, image_downsampling, masks, mask_ids, sam_images, sam_size, image2maskids
463
+
464
+ def __getitem__(self, i) -> Dict[str, torch.Tensor]:
465
+ data_dict = self.list_data_dict[i]
466
+
467
+ try:
468
+ modal, images, messages, image_downsampling, masks, mask_ids, sam_images, sam_size, image2maskids = self._convert_normal(data_dict)
469
+
470
+ data_dict = self.vlprocessor(
471
+ images=images,
472
+ text=messages,
473
+ merge_size=image_downsampling,
474
+ return_labels=True,
475
+ return_tensors="pt",
476
+ )
477
+
478
+ if modal == 'text':
479
+ unit_size = self.vlprocessor.image_processor.patch_size**2 * 3
480
+ data_dict['pixel_values'] = torch.zeros(self.vlprocessor.image_merge_size**2, unit_size)
481
+ data_dict['grid_sizes'] = torch.as_tensor([[1, self.vlprocessor.image_merge_size, self.vlprocessor.image_merge_size]])
482
+ data_dict['merge_sizes'] = torch.as_tensor([self.vlprocessor.image_merge_size])
483
+ elif modal == 'image' or modal == 'video':
484
+ assert len(data_dict['pixel_values']) > 0 and len(data_dict['grid_sizes']) > 0, f"Invalid image data: {data_dict['pixel_values']}, {data_dict['grid_sizes']}"
485
+
486
+ data_dict['modals'] = [modal] if isinstance(modal, str) else modal
487
+ data_dict['masks'] = masks
488
+ data_dict['mask_ids'] = mask_ids
489
+ data_dict['idx'] = i
490
+ data_dict['sam_images'] = sam_images
491
+ data_dict['sam_size'] = sam_size
492
+ data_dict['image2maskids'] = image2maskids
493
+
494
+ except Exception as e:
495
+ traceback.print_exc()
496
+ backup_idx = random.randint(0, len(self.list_data_dict) - 1)
497
+ print(f"Encounted error when process {i}-th example: {data_dict}, use {backup_idx}-th example instead!!!")
498
+ return self.__getitem__(backup_idx)
499
+
500
+ return data_dict
501
+
502
+
503
+ @dataclass
504
+ class DataCollatorForSupervisedDataset(object):
505
+ """Collate examples for supervised fine-tuning."""
506
+
507
+ vlprocessor: transformers.ProcessorMixin
508
+
509
+ def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
510
+ input_ids, labels = tuple([instance[key] for instance in instances]
511
+ for key in ("input_ids", "labels"))
512
+ input_ids = torch.nn.utils.rnn.pad_sequence(
513
+ input_ids,
514
+ batch_first=True,
515
+ padding_value=self.vlprocessor.tokenizer.pad_token_id)
516
+ labels = torch.nn.utils.rnn.pad_sequence(labels,
517
+ batch_first=True,
518
+ padding_value=IGNORE_INDEX)
519
+ input_ids = input_ids[:, :self.vlprocessor.tokenizer.model_max_length]
520
+ labels = labels[:, :self.vlprocessor.tokenizer.model_max_length]
521
+ attention_mask = input_ids.ne(self.vlprocessor.tokenizer.pad_token_id)
522
+ position_ids = attention_mask.cumsum(dim=-1) - 1
523
+
524
+ batch = dict(
525
+ input_ids=input_ids,
526
+ labels=labels,
527
+ attention_mask=input_ids.ne(self.vlprocessor.tokenizer.pad_token_id),
528
+ position_ids=position_ids
529
+ )
530
+
531
+ # used as the 'images' argument in `prepare_inputs_labels_for_multimodal`
532
+ batch["pixel_values"] = torch.cat([x["pixel_values"] for x in instances])
533
+ batch["grid_sizes"] = torch.cat([x["grid_sizes"] for x in instances])
534
+ batch["merge_sizes"] = torch.cat([x["merge_sizes"] for x in instances])
535
+ batch["modals"] = sum([x["modals"] for x in instances], [])
536
+
537
+ batch['mask_ids'] = []
538
+ mask_idx_start = 0
539
+ for instance in instances:
540
+ if len(instance['mask_ids'])>0:
541
+ batch['mask_ids'].extend([idx+mask_idx_start for idx in instance['mask_ids']])
542
+ # print(int(instance['grid_sizes'][0][0]))
543
+
544
+ mask_idx_start += int(instance['grid_sizes'][0][0])
545
+ batch["masks"] = [x["masks"] for x in instances]
546
+ batch["sam_images"] = [x["sam_images"] for x in instances]
547
+ batch["sam_size"] = [x["sam_size"] for x in instances]
548
+ batch["image2maskids"] = [x["image2maskids"] for x in instances]
549
+ batch["idxes"] = [x["idx"] for x in instances]
550
+ return batch
551
+
552
+
553
+ def make_supervised_data_module(vlprocessor, data_args) -> Dict:
554
+ """Make dataset and collator for supervised fine-tuning."""
555
+ train_dataset = LazySupervisedDataset(
556
+ vlprocessor=vlprocessor,
557
+ # data_folder=data_args.data_folder,
558
+ data_path=data_args.data_path,
559
+ data_args=data_args
560
+ )
561
+ data_collator = DataCollatorForSupervisedDataset(vlprocessor=vlprocessor)
562
+ return dict(train_dataset=train_dataset,
563
+ eval_dataset=None,
564
+ data_collator=data_collator)
565
+
566
+
567
+ def train(attn_implementation=None):
568
+ global local_rank
569
+ set_seed(42)
570
+
571
+ parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
572
+ model_args, data_args, training_args = parser.parse_args_into_dataclasses()
573
+
574
+ local_rank = training_args.local_rank
575
+
576
+ if local_rank == 0:
577
+ print('------model args------')
578
+ print(model_args)
579
+ print('------data args------')
580
+ print(data_args)
581
+ print('------training args------')
582
+ print(training_args)
583
+
584
+ compute_dtype = (torch.float16 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32))
585
+
586
+ bnb_model_from_pretrained_args = {}
587
+ if training_args.bits in [4, 8]:
588
+ from transformers import BitsAndBytesConfig
589
+ bnb_model_from_pretrained_args.update(dict(
590
+ # device_map={"": training_args.device},
591
+ # BUG: newer versions of transformers raise an error:
592
+ # ValueError: You can't pass `load_in_4bit`or `load_in_8bit` as a kwarg when passing `quantization_config` argument at the same time
593
+ # load_in_4bit=training_args.bits == 4,
594
+ # load_in_8bit=training_args.bits == 8,
595
+ quantization_config=BitsAndBytesConfig(
596
+ load_in_4bit=training_args.bits == 4,
597
+ load_in_8bit=training_args.bits == 8,
598
+ llm_int8_skip_modules=["mm_projector"],
599
+ llm_int8_threshold=6.0,
600
+ llm_int8_has_fp16_weight=False,
601
+ bnb_4bit_compute_dtype=compute_dtype,
602
+ bnb_4bit_use_double_quant=training_args.double_quant,
603
+ bnb_4bit_quant_type=training_args.quant_type, # {'fp4', 'nf4'}
604
+ bnb_4bit_quant_storage=compute_dtype,
605
+ )
606
+ ))
607
+
608
+ config = RynnecQwen2Config.from_pretrained(model_args.model_path)
609
+
610
+ config._attn_implementation = attn_implementation
611
+ # NOTE: active spatial_merge_size arguments
612
+ config.spatial_merge_size = model_args.spatial_merge_size
613
+ config.mm_max_length = model_args.mm_max_length
614
+ config.use_token_compression = model_args.use_token_compression
615
+ config.loss_reduction_scope = training_args.loss_reduction_scope
616
+ config.mask_decoder_model = model_args.mask_decoder_model
617
+ config.training = model_args.training
618
+ config.has_mask = model_args.has_mask
619
+
620
+ if model_args.vision_encoder is not None:
621
+ model = RynnecQwen2ForCausalLM.from_pretrained(
622
+ model_args.model_path,
623
+ config=config,
624
+ torch_dtype=compute_dtype,
625
+ do_sample=True,
626
+ **bnb_model_from_pretrained_args
627
+ )
628
+ if 'mixtral' in model_args.model_type:
629
+ import deepspeed
630
+ deepspeed.utils.set_z3_leaf_modules(model, [MixtralSparseMoeBlock])
631
+ else:
632
+ model = transformers.LlamaForCausalLM.from_pretrained(
633
+ model_args.model_path,
634
+ config=config,
635
+ torch_dtype=compute_dtype,
636
+ do_sample=True,
637
+ **bnb_model_from_pretrained_args
638
+ )
639
+ model.config.use_cache = False
640
+ if model_args.freeze_backbone:
641
+ model.model.requires_grad_(False)
642
+
643
+ if training_args.bits in [4, 8]:
644
+ from peft import prepare_model_for_kbit_training
645
+ model.config.torch_dtype=(torch.float32 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32))
646
+ model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=training_args.gradient_checkpointing)
647
+
648
+ if training_args.gradient_checkpointing:
649
+ if hasattr(model, "enable_input_require_grads"):
650
+ model.enable_input_require_grads()
651
+ else:
652
+ def make_inputs_require_grad(module, input, output):
653
+ output.requires_grad_(True)
654
+ model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
655
+
656
+ if training_args.lora_enable:
657
+ from peft import LoraConfig, get_peft_model
658
+ lora_config = LoraConfig(
659
+ r=training_args.lora_r,
660
+ lora_alpha=training_args.lora_alpha,
661
+ target_modules=find_all_linear_names(model),
662
+ lora_dropout=training_args.lora_dropout,
663
+ bias=training_args.lora_bias,
664
+ task_type="CAUSAL_LM",
665
+ )
666
+ if training_args.bits == 16:
667
+ if training_args.bf16:
668
+ model.to(torch.bfloat16)
669
+ if training_args.fp16:
670
+ model.to(torch.float16)
671
+ rank0_print("Adding LoRA adapters...")
672
+ model = get_peft_model(model, lora_config)
673
+
674
+ tokenizer = transformers.AutoTokenizer.from_pretrained(
675
+ model_args.model_path,
676
+ model_max_length=training_args.model_max_length,
677
+ padding_side="right",
678
+ use_fast=True,
679
+ )
680
+
681
+ if tokenizer.pad_token is None:
682
+ tokenizer.pad_token = tokenizer.unk_token
683
+
684
+ if model_args.vision_encoder is not None:
685
+ # initialize vision encoder + multi-modal projector
686
+ model.get_model().initialize_vision_modules(model_args=model_args, fsdp=training_args.fsdp)
687
+
688
+ if model_args.load_sam2_weight is True:
689
+ model.get_model().build_mask_decoder(model.get_model().config)
690
+ model.load_sam2_weights(model_args.mask_decoder_model)
691
+
692
+ vision_encoder = model.get_vision_encoder()
693
+ vision_encoder.to(dtype=compute_dtype, device=training_args.device)
694
+
695
+ vision_encoder.image_processor.max_tokens = model_args.mm_max_length
696
+ mm_projector = model.get_mm_projector()
697
+ mm_projector.to(dtype=compute_dtype if training_args.bf16 else torch.float16, device=training_args.device)
698
+
699
+ data_args.is_multimodal = True
700
+
701
+ model.config.tokenizer_padding_side = tokenizer.padding_side
702
+ model.config.tokenizer_model_max_length = tokenizer.model_max_length
703
+
704
+ if training_args.bits in [4, 8]:
705
+ model.get_model().mm_projector.to(dtype=compute_dtype, device=training_args.device)
706
+
707
+ # decoupled learning rate
708
+ model.config.llm_lr = training_args.llm_lr
709
+ model.config.vision_encoder_lr = training_args.vision_encoder_lr
710
+ model.config.mm_projector_lr = training_args.mm_projector_lr
711
+ model.config.region_encoder_lr = training_args.region_encoder_lr
712
+ model.config.sam_decoder_lr = training_args.sam_decoder_lr
713
+ model.config.sam_encoder_lr = training_args.sam_encoder_lr
714
+ model.config.dice_loss_weight = 0.5
715
+ model.config.bce_loss_weight = 2.0
716
+
717
+ if model.config.llm_lr is None:
718
+ for p in model.get_model().parameters():
719
+ p.requires_grad = False
720
+ for p in model.get_model().vision_encoder.parameters():
721
+ p.requires_grad = True
722
+ for p in model.get_model().mm_projector.parameters():
723
+ p.requires_grad = True
724
+ for p in model.get_model().region_encoder.parameters():
725
+ p.requires_grad = True
726
+
727
+
728
+ if model.config.vision_encoder_lr is None:
729
+ for p in model.get_model().vision_encoder.parameters():
730
+ p.requires_grad = False
731
+
732
+ if model.config.mm_projector_lr is None:
733
+ for p in model.get_model().mm_projector.parameters():
734
+ p.requires_grad = False
735
+
736
+ if model.config.region_encoder_lr is None:
737
+ for p in model.get_model().region_encoder.parameters():
738
+ p.requires_grad = False
739
+
740
+ if model.config.sam_decoder_lr is None:
741
+ for p in model.grounding_encoder.sam2_model.sam_mask_decoder.parameters():
742
+ p.requires_grad = False
743
+ else:
744
+ for p in model.grounding_encoder.sam2_model.sam_mask_decoder.parameters():
745
+ p.requires_grad = True
746
+
747
+ if model.config.sam_encoder_lr is None:
748
+ for p in model.grounding_encoder.sam2_model.image_encoder.parameters():
749
+ p.requires_grad = False
750
+
751
+ if training_args.lora_enable:
752
+ for n, p in model.named_parameters():
753
+ if any(
754
+ [
755
+ x in n
756
+ for x in ["lm_head", "embed_tokens", "text_hidden_fcs"]
757
+ ]
758
+ ):
759
+ # print(n)
760
+ p.requires_grad = True
761
+
762
+ model.config.max_frames = getattr(data_args, 'max_frames', NUM_FRAMES)
763
+ model.config.image_aspect_ratio = data_args.image_aspect_ratio if 'qwen2vl' not in model_args.vision_encoder else 'qwen2vl'
764
+
765
+ # NOTE: fill in data_args from model hyperparameters
766
+ # 1. acquire image size
767
+ model.config.image_size = data_args.image_size = vision_encoder.image_size
768
+ # 2. calculate the number of tokens in the image
769
+ model.config.image_token_length = data_args.image_token_length = mm_projector.cal_proj_size(vision_encoder.num_patches_per_side)
770
+ # 3. check if alignment
771
+ model.config.is_alignment = training_args.is_alignment = data_args.is_alignment = (
772
+ model.config.mm_projector_lr is not None and
773
+ model.config.llm_lr is None and
774
+ model.config.vision_encoder_lr is None
775
+ )
776
+ # 4. set spatial merge size as default
777
+ model.config.spatial_merge_size = data_args.spatial_merge_size = model_args.spatial_merge_size
778
+ tokenizer.add_tokens([DEFAULT_IMAGE_TOKEN, STREAM_START_TOKEN, STREAM_END_TOKEN], special_tokens=True)
779
+ tokenizer.add_tokens([REGION_TOKEN], special_tokens=True)
780
+ num_new_tokens = tokenizer.add_tokens([SEG_TOKEN], special_tokens=True)
781
+ model.resize_token_embeddings(len(tokenizer))
782
+
783
+ model.config.image_token_index = tokenizer.convert_tokens_to_ids(DEFAULT_IMAGE_TOKEN)
784
+ model.config.region_token_index = tokenizer.convert_tokens_to_ids(REGION_TOKEN)
785
+ model.config.seg_token_index = tokenizer.convert_tokens_to_ids(SEG_TOKEN)
786
+
787
+ vlprocessor = Videollama3Qwen2Processor(vision_encoder.image_processor, tokenizer)
788
+
789
+ if training_args.bits in [4, 8]:
790
+ from peft.tuners.lora import LoraLayer
791
+ for name, module in model.named_modules():
792
+ if isinstance(module, LoraLayer):
793
+ if training_args.bf16:
794
+ module = module.to(torch.bfloat16)
795
+ if 'norm' in name:
796
+ module = module.to(torch.float32)
797
+ if 'lm_head' in name or 'embed_tokens' in name:
798
+ if hasattr(module, 'weight'):
799
+ if training_args.bf16 and module.weight.dtype == torch.float32:
800
+ module = module.to(torch.bfloat16)
801
+
802
+ if local_rank == 0:
803
+ print("Current model:", model)
804
+ print("Model config:", model.config)
805
+
806
+
807
+ data_module = make_supervised_data_module(vlprocessor=vlprocessor, data_args=data_args)
808
+
809
+ # select a Trainer
810
+ trainer = RynnECTrainer(model=model, tokenizer=tokenizer, args=training_args, **data_module)
811
+
812
+ if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
813
+ trainer.train(resume_from_checkpoint=True)
814
+ else:
815
+ trainer.train()
816
+ trainer.save_state()
817
+
818
+ model.config.use_cache = True
819
+
820
+ if training_args.lora_enable:
821
+ state_dict = get_peft_state_maybe_zero_3(model.named_parameters(), training_args.lora_bias)
822
+ non_lora_state_dict = get_peft_state_non_lora_maybe_zero_3(model.named_parameters())
823
+ if training_args.local_rank == 0 or training_args.local_rank == -1:
824
+ model.config.save_pretrained(training_args.output_dir)
825
+ model.save_pretrained(training_args.output_dir, state_dict=state_dict)
826
+ torch.save(non_lora_state_dict, os.path.join(training_args.output_dir, 'non_lora_trainables.bin'))
827
+ else:
828
+ safe_save_model_for_hf_trainer(trainer=trainer, output_dir=training_args.output_dir)
829
+
830
+
831
+ if __name__ == "__main__":
832
+ train(attn_implementation="flash_attention_2")
third_parts/sam2/__init__.py ADDED
@@ -0,0 +1,9 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from hydra import initialize_config_module
8
+
9
+ initialize_config_module("third_parts.sam2.sam2_configs", version_base="1.2")
third_parts/sam2/automatic_mask_generator.py ADDED
@@ -0,0 +1,434 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # Adapted from https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/automatic_mask_generator.py
8
+ from typing import Any, Dict, List, Optional, Tuple
9
+
10
+ import numpy as np
11
+ import torch
12
+ from torchvision.ops.boxes import batched_nms, box_area # type: ignore
13
+
14
+ from third_parts.sam2.modeling.sam2_base import SAM2Base
15
+ from third_parts.sam2.sam2_image_predictor import SAM2ImagePredictor
16
+ from third_parts.sam2.utils.amg import (
17
+ area_from_rle,
18
+ batch_iterator,
19
+ batched_mask_to_box,
20
+ box_xyxy_to_xywh,
21
+ build_all_layer_point_grids,
22
+ calculate_stability_score,
23
+ coco_encode_rle,
24
+ generate_crop_boxes,
25
+ is_box_near_crop_edge,
26
+ mask_to_rle_pytorch,
27
+ MaskData,
28
+ remove_small_regions,
29
+ rle_to_mask,
30
+ uncrop_boxes_xyxy,
31
+ uncrop_masks,
32
+ uncrop_points,
33
+ )
34
+
35
+
36
+ class SAM2AutomaticMaskGenerator:
37
+ def __init__(
38
+ self,
39
+ model: SAM2Base,
40
+ points_per_side: Optional[int] = 32,
41
+ points_per_batch: int = 64,
42
+ pred_iou_thresh: float = 0.8,
43
+ stability_score_thresh: float = 0.95,
44
+ stability_score_offset: float = 1.0,
45
+ mask_threshold: float = 0.0,
46
+ box_nms_thresh: float = 0.7,
47
+ crop_n_layers: int = 0,
48
+ crop_nms_thresh: float = 0.7,
49
+ crop_overlap_ratio: float = 512 / 1500,
50
+ crop_n_points_downscale_factor: int = 1,
51
+ point_grids: Optional[List[np.ndarray]] = None,
52
+ min_mask_region_area: int = 0,
53
+ output_mode: str = "binary_mask",
54
+ use_m2m: bool = False,
55
+ multimask_output: bool = True,
56
+ ) -> None:
57
+ """
58
+ Using a SAM 2 model, generates masks for the entire image.
59
+ Generates a grid of point prompts over the image, then filters
60
+ low quality and duplicate masks. The default settings are chosen
61
+ for SAM 2 with a HieraL backbone.
62
+
63
+ Arguments:
64
+ model (Sam): The SAM 2 model to use for mask prediction.
65
+ points_per_side (int or None): The number of points to be sampled
66
+ along one side of the image. The total number of points is
67
+ points_per_side**2. If None, 'point_grids' must provide explicit
68
+ point sampling.
69
+ points_per_batch (int): Sets the number of points run simultaneously
70
+ by the model. Higher numbers may be faster but use more GPU memory.
71
+ pred_iou_thresh (float): A filtering threshold in [0,1], using the
72
+ model's predicted mask quality.
73
+ stability_score_thresh (float): A filtering threshold in [0,1], using
74
+ the stability of the mask under changes to the cutoff used to binarize
75
+ the model's mask predictions.
76
+ stability_score_offset (float): The amount to shift the cutoff when
77
+ calculated the stability score.
78
+ mask_threshold (float): Threshold for binarizing the mask logits
79
+ box_nms_thresh (float): The box IoU cutoff used by non-maximal
80
+ suppression to filter duplicate masks.
81
+ crop_n_layers (int): If >0, mask prediction will be run again on
82
+ crops of the image. Sets the number of layers to run, where each
83
+ layer has 2**i_layer number of image crops.
84
+ crop_nms_thresh (float): The box IoU cutoff used by non-maximal
85
+ suppression to filter duplicate masks between different crops.
86
+ crop_overlap_ratio (float): Sets the degree to which crops overlap.
87
+ In the first crop layer, crops will overlap by this fraction of
88
+ the image length. Later layers with more crops scale down this overlap.
89
+ crop_n_points_downscale_factor (int): The number of points-per-side
90
+ sampled in layer n is scaled down by crop_n_points_downscale_factor**n.
91
+ point_grids (list(np.ndarray) or None): A list over explicit grids
92
+ of points used for sampling, normalized to [0,1]. The nth grid in the
93
+ list is used in the nth crop layer. Exclusive with points_per_side.
94
+ min_mask_region_area (int): If >0, postprocessing will be applied
95
+ to remove disconnected regions and holes in masks with area smaller
96
+ than min_mask_region_area. Requires opencv.
97
+ output_mode (str): The form masks are returned in. Can be 'binary_mask',
98
+ 'uncompressed_rle', or 'coco_rle'. 'coco_rle' requires pycocotools.
99
+ For large resolutions, 'binary_mask' may consume large amounts of
100
+ memory.
101
+ use_m2m (bool): Whether to add a one step refinement using previous mask predictions.
102
+ multimask_output (bool): Whether to output multimask at each point of the grid.
103
+ """
104
+
105
+ assert (points_per_side is None) != (
106
+ point_grids is None
107
+ ), "Exactly one of points_per_side or point_grid must be provided."
108
+ if points_per_side is not None:
109
+ self.point_grids = build_all_layer_point_grids(
110
+ points_per_side,
111
+ crop_n_layers,
112
+ crop_n_points_downscale_factor,
113
+ )
114
+ elif point_grids is not None:
115
+ self.point_grids = point_grids
116
+ else:
117
+ raise ValueError("Can't have both points_per_side and point_grid be None.")
118
+
119
+ assert output_mode in [
120
+ "binary_mask",
121
+ "uncompressed_rle",
122
+ "coco_rle",
123
+ ], f"Unknown output_mode {output_mode}."
124
+ if output_mode == "coco_rle":
125
+ try:
126
+ from pycocotools import mask as mask_utils # type: ignore # noqa: F401
127
+ except ImportError as e:
128
+ print("Please install pycocotools")
129
+ raise e
130
+
131
+ self.predictor = SAM2ImagePredictor(
132
+ model,
133
+ max_hole_area=min_mask_region_area,
134
+ max_sprinkle_area=min_mask_region_area,
135
+ )
136
+ self.points_per_batch = points_per_batch
137
+ self.pred_iou_thresh = pred_iou_thresh
138
+ self.stability_score_thresh = stability_score_thresh
139
+ self.stability_score_offset = stability_score_offset
140
+ self.mask_threshold = mask_threshold
141
+ self.box_nms_thresh = box_nms_thresh
142
+ self.crop_n_layers = crop_n_layers
143
+ self.crop_nms_thresh = crop_nms_thresh
144
+ self.crop_overlap_ratio = crop_overlap_ratio
145
+ self.crop_n_points_downscale_factor = crop_n_points_downscale_factor
146
+ self.min_mask_region_area = min_mask_region_area
147
+ self.output_mode = output_mode
148
+ self.use_m2m = use_m2m
149
+ self.multimask_output = multimask_output
150
+
151
+ @torch.no_grad()
152
+ def generate(self, image: np.ndarray) -> List[Dict[str, Any]]:
153
+ """
154
+ Generates masks for the given image.
155
+
156
+ Arguments:
157
+ image (np.ndarray): The image to generate masks for, in HWC uint8 format.
158
+
159
+ Returns:
160
+ list(dict(str, any)): A list over records for masks. Each record is
161
+ a dict containing the following keys:
162
+ segmentation (dict(str, any) or np.ndarray): The mask. If
163
+ output_mode='binary_mask', is an array of shape HW. Otherwise,
164
+ is a dictionary containing the RLE.
165
+ bbox (list(float)): The box around the mask, in XYWH format.
166
+ area (int): The area in pixels of the mask.
167
+ predicted_iou (float): The model's own prediction of the mask's
168
+ quality. This is filtered by the pred_iou_thresh parameter.
169
+ point_coords (list(list(float))): The point coordinates input
170
+ to the model to generate this mask.
171
+ stability_score (float): A measure of the mask's quality. This
172
+ is filtered on using the stability_score_thresh parameter.
173
+ crop_box (list(float)): The crop of the image used to generate
174
+ the mask, given in XYWH format.
175
+ """
176
+
177
+ # Generate masks
178
+ mask_data = self._generate_masks(image)
179
+
180
+ # Encode masks
181
+ if self.output_mode == "coco_rle":
182
+ mask_data["segmentations"] = [
183
+ coco_encode_rle(rle) for rle in mask_data["rles"]
184
+ ]
185
+ elif self.output_mode == "binary_mask":
186
+ mask_data["segmentations"] = [rle_to_mask(rle) for rle in mask_data["rles"]]
187
+ else:
188
+ mask_data["segmentations"] = mask_data["rles"]
189
+
190
+ # Write mask records
191
+ curr_anns = []
192
+ for idx in range(len(mask_data["segmentations"])):
193
+ ann = {
194
+ "segmentation": mask_data["segmentations"][idx],
195
+ "area": area_from_rle(mask_data["rles"][idx]),
196
+ "bbox": box_xyxy_to_xywh(mask_data["boxes"][idx]).tolist(),
197
+ "predicted_iou": mask_data["iou_preds"][idx].item(),
198
+ "point_coords": [mask_data["points"][idx].tolist()],
199
+ "stability_score": mask_data["stability_score"][idx].item(),
200
+ "crop_box": box_xyxy_to_xywh(mask_data["crop_boxes"][idx]).tolist(),
201
+ }
202
+ curr_anns.append(ann)
203
+
204
+ return curr_anns
205
+
206
+ def _generate_masks(self, image: np.ndarray) -> MaskData:
207
+ orig_size = image.shape[:2]
208
+ crop_boxes, layer_idxs = generate_crop_boxes(
209
+ orig_size, self.crop_n_layers, self.crop_overlap_ratio
210
+ )
211
+
212
+ # Iterate over image crops
213
+ data = MaskData()
214
+ for crop_box, layer_idx in zip(crop_boxes, layer_idxs):
215
+ crop_data = self._process_crop(image, crop_box, layer_idx, orig_size)
216
+ data.cat(crop_data)
217
+
218
+ # Remove duplicate masks between crops
219
+ if len(crop_boxes) > 1:
220
+ # Prefer masks from smaller crops
221
+ scores = 1 / box_area(data["crop_boxes"])
222
+ scores = scores.to(data["boxes"].device)
223
+ keep_by_nms = batched_nms(
224
+ data["boxes"].float(),
225
+ scores,
226
+ torch.zeros_like(data["boxes"][:, 0]), # categories
227
+ iou_threshold=self.crop_nms_thresh,
228
+ )
229
+ data.filter(keep_by_nms)
230
+ data.to_numpy()
231
+ return data
232
+
233
+ def _process_crop(
234
+ self,
235
+ image: np.ndarray,
236
+ crop_box: List[int],
237
+ crop_layer_idx: int,
238
+ orig_size: Tuple[int, ...],
239
+ ) -> MaskData:
240
+ # Crop the image and calculate embeddings
241
+ x0, y0, x1, y1 = crop_box
242
+ cropped_im = image[y0:y1, x0:x1, :]
243
+ cropped_im_size = cropped_im.shape[:2]
244
+ self.predictor.set_image(cropped_im)
245
+
246
+ # Get points for this crop
247
+ points_scale = np.array(cropped_im_size)[None, ::-1]
248
+ points_for_image = self.point_grids[crop_layer_idx] * points_scale
249
+
250
+ # Generate masks for this crop in batches
251
+ data = MaskData()
252
+ for (points,) in batch_iterator(self.points_per_batch, points_for_image):
253
+ batch_data = self._process_batch(
254
+ points, cropped_im_size, crop_box, orig_size, normalize=True
255
+ )
256
+ data.cat(batch_data)
257
+ del batch_data
258
+ self.predictor.reset_predictor()
259
+
260
+ # Remove duplicates within this crop.
261
+ keep_by_nms = batched_nms(
262
+ data["boxes"].float(),
263
+ data["iou_preds"],
264
+ torch.zeros_like(data["boxes"][:, 0]), # categories
265
+ iou_threshold=self.box_nms_thresh,
266
+ )
267
+ data.filter(keep_by_nms)
268
+
269
+ # Return to the original image frame
270
+ data["boxes"] = uncrop_boxes_xyxy(data["boxes"], crop_box)
271
+ data["points"] = uncrop_points(data["points"], crop_box)
272
+ data["crop_boxes"] = torch.tensor([crop_box for _ in range(len(data["rles"]))])
273
+
274
+ return data
275
+
276
+ def _process_batch(
277
+ self,
278
+ points: np.ndarray,
279
+ im_size: Tuple[int, ...],
280
+ crop_box: List[int],
281
+ orig_size: Tuple[int, ...],
282
+ normalize=False,
283
+ ) -> MaskData:
284
+ orig_h, orig_w = orig_size
285
+
286
+ # Run model on this batch
287
+ points = torch.as_tensor(points, device=self.predictor.device)
288
+ in_points = self.predictor._transforms.transform_coords(
289
+ points, normalize=normalize, orig_hw=im_size
290
+ )
291
+ in_labels = torch.ones(
292
+ in_points.shape[0], dtype=torch.int, device=in_points.device
293
+ )
294
+ masks, iou_preds, low_res_masks = self.predictor._predict(
295
+ in_points[:, None, :],
296
+ in_labels[:, None],
297
+ multimask_output=self.multimask_output,
298
+ return_logits=True,
299
+ )
300
+
301
+ # Serialize predictions and store in MaskData
302
+ data = MaskData(
303
+ masks=masks.flatten(0, 1),
304
+ iou_preds=iou_preds.flatten(0, 1),
305
+ points=points.repeat_interleave(masks.shape[1], dim=0),
306
+ low_res_masks=low_res_masks.flatten(0, 1),
307
+ )
308
+ del masks
309
+
310
+ if not self.use_m2m:
311
+ # Filter by predicted IoU
312
+ if self.pred_iou_thresh > 0.0:
313
+ keep_mask = data["iou_preds"] > self.pred_iou_thresh
314
+ data.filter(keep_mask)
315
+
316
+ # Calculate and filter by stability score
317
+ data["stability_score"] = calculate_stability_score(
318
+ data["masks"], self.mask_threshold, self.stability_score_offset
319
+ )
320
+ if self.stability_score_thresh > 0.0:
321
+ keep_mask = data["stability_score"] >= self.stability_score_thresh
322
+ data.filter(keep_mask)
323
+ else:
324
+ # One step refinement using previous mask predictions
325
+ in_points = self.predictor._transforms.transform_coords(
326
+ data["points"], normalize=normalize, orig_hw=im_size
327
+ )
328
+ labels = torch.ones(
329
+ in_points.shape[0], dtype=torch.int, device=in_points.device
330
+ )
331
+ masks, ious = self.refine_with_m2m(
332
+ in_points, labels, data["low_res_masks"], self.points_per_batch
333
+ )
334
+ data["masks"] = masks.squeeze(1)
335
+ data["iou_preds"] = ious.squeeze(1)
336
+
337
+ if self.pred_iou_thresh > 0.0:
338
+ keep_mask = data["iou_preds"] > self.pred_iou_thresh
339
+ data.filter(keep_mask)
340
+
341
+ data["stability_score"] = calculate_stability_score(
342
+ data["masks"], self.mask_threshold, self.stability_score_offset
343
+ )
344
+ if self.stability_score_thresh > 0.0:
345
+ keep_mask = data["stability_score"] >= self.stability_score_thresh
346
+ data.filter(keep_mask)
347
+
348
+ # Threshold masks and calculate boxes
349
+ data["masks"] = data["masks"] > self.mask_threshold
350
+ data["boxes"] = batched_mask_to_box(data["masks"])
351
+
352
+ # Filter boxes that touch crop boundaries
353
+ keep_mask = ~is_box_near_crop_edge(
354
+ data["boxes"], crop_box, [0, 0, orig_w, orig_h]
355
+ )
356
+ if not torch.all(keep_mask):
357
+ data.filter(keep_mask)
358
+
359
+ # Compress to RLE
360
+ data["masks"] = uncrop_masks(data["masks"], crop_box, orig_h, orig_w)
361
+ data["rles"] = mask_to_rle_pytorch(data["masks"])
362
+ del data["masks"]
363
+
364
+ return data
365
+
366
+ @staticmethod
367
+ def postprocess_small_regions(
368
+ mask_data: MaskData, min_area: int, nms_thresh: float
369
+ ) -> MaskData:
370
+ """
371
+ Removes small disconnected regions and holes in masks, then reruns
372
+ box NMS to remove any new duplicates.
373
+
374
+ Edits mask_data in place.
375
+
376
+ Requires open-cv as a dependency.
377
+ """
378
+ if len(mask_data["rles"]) == 0:
379
+ return mask_data
380
+
381
+ # Filter small disconnected regions and holes
382
+ new_masks = []
383
+ scores = []
384
+ for rle in mask_data["rles"]:
385
+ mask = rle_to_mask(rle)
386
+
387
+ mask, changed = remove_small_regions(mask, min_area, mode="holes")
388
+ unchanged = not changed
389
+ mask, changed = remove_small_regions(mask, min_area, mode="islands")
390
+ unchanged = unchanged and not changed
391
+
392
+ new_masks.append(torch.as_tensor(mask).unsqueeze(0))
393
+ # Give score=0 to changed masks and score=1 to unchanged masks
394
+ # so NMS will prefer ones that didn't need postprocessing
395
+ scores.append(float(unchanged))
396
+
397
+ # Recalculate boxes and remove any new duplicates
398
+ masks = torch.cat(new_masks, dim=0)
399
+ boxes = batched_mask_to_box(masks)
400
+ keep_by_nms = batched_nms(
401
+ boxes.float(),
402
+ torch.as_tensor(scores),
403
+ torch.zeros_like(boxes[:, 0]), # categories
404
+ iou_threshold=nms_thresh,
405
+ )
406
+
407
+ # Only recalculate RLEs for masks that have changed
408
+ for i_mask in keep_by_nms:
409
+ if scores[i_mask] == 0.0:
410
+ mask_torch = masks[i_mask].unsqueeze(0)
411
+ mask_data["rles"][i_mask] = mask_to_rle_pytorch(mask_torch)[0]
412
+ mask_data["boxes"][i_mask] = boxes[i_mask] # update res directly
413
+ mask_data.filter(keep_by_nms)
414
+
415
+ return mask_data
416
+
417
+ def refine_with_m2m(self, points, point_labels, low_res_masks, points_per_batch):
418
+ new_masks = []
419
+ new_iou_preds = []
420
+
421
+ for cur_points, cur_point_labels, low_res_mask in batch_iterator(
422
+ points_per_batch, points, point_labels, low_res_masks
423
+ ):
424
+ best_masks, best_iou_preds, _ = self.predictor._predict(
425
+ cur_points[:, None, :],
426
+ cur_point_labels[:, None],
427
+ mask_input=low_res_mask[:, None, :],
428
+ multimask_output=False,
429
+ return_logits=True,
430
+ )
431
+ new_masks.append(best_masks)
432
+ new_iou_preds.append(best_iou_preds)
433
+ masks = torch.cat(new_masks, dim=0)
434
+ return masks, torch.cat(new_iou_preds, dim=0)
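For reference, a minimal usage sketch of the automatic mask generator added above. It assumes the class in this file keeps the upstream SAM 2 name SAM2AutomaticMaskGenerator, that the config directory and checkpoint path below exist locally, that Hydra >= 1.2 is installed, and that a CUDA device is available; adjust names to your setup.

import os
import numpy as np
from hydra import initialize_config_dir
from third_parts.sam2.build_sam import build_sam2
from third_parts.sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator  # assumed class name

# build_sam2 resolves the config name via Hydra's compose(), so Hydra must be initialized first.
with initialize_config_dir(config_dir=os.path.abspath("third_parts/sam2/sam2_configs"), version_base=None):
    sam2 = build_sam2("sam2_hiera_l.yaml", ckpt_path="checkpoints/sam2_hiera_large.pt", device="cuda")

mask_generator = SAM2AutomaticMaskGenerator(
    sam2,
    points_per_side=32,            # 32**2 point prompts on a regular grid
    pred_iou_thresh=0.8,
    stability_score_thresh=0.95,
    output_mode="binary_mask",
)

image = np.random.randint(0, 255, (512, 512, 3), dtype=np.uint8)  # HWC uint8, as generate() expects
for record in mask_generator.generate(image):
    print(record["bbox"], record["area"], record["predicted_iou"])  # keys documented in generate()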
third_parts/sam2/build_sam.py ADDED
@@ -0,0 +1,89 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import logging
8
+
9
+ import torch
10
+ from hydra import compose
11
+ from hydra.utils import instantiate
12
+ from omegaconf import OmegaConf
13
+
14
+
15
+ def build_sam2(
16
+ config_file,
17
+ ckpt_path=None,
18
+ device="cuda",
19
+ mode="eval",
20
+ hydra_overrides_extra=[],
21
+ apply_postprocessing=True,
22
+ ):
23
+
24
+ if apply_postprocessing:
25
+ hydra_overrides_extra = hydra_overrides_extra.copy()
26
+ hydra_overrides_extra += [
27
+ # dynamically fall back to multi-mask if the single mask is not stable
28
+ "++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true",
29
+ "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05",
30
+ "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98",
31
+ ]
32
+ # Read config and init model
33
+ cfg = compose(config_name=config_file, overrides=hydra_overrides_extra)
34
+ OmegaConf.resolve(cfg)
35
+ model = instantiate(cfg.model, _recursive_=True)
36
+ _load_checkpoint(model, ckpt_path)
37
+ model = model.to(device)
38
+ if mode == "eval":
39
+ model.eval()
40
+ return model
41
+
42
+
43
+ def build_sam2_video_predictor(
44
+ config_file,
45
+ ckpt_path=None,
46
+ device="cuda",
47
+ mode="eval",
48
+ hydra_overrides_extra=[],
49
+ apply_postprocessing=True,
50
+ ):
51
+ hydra_overrides = [
52
+ "++model._target_=third_parts.sam2.sam2_video_predictor.SAM2VideoPredictor",
53
+ ]
54
+ if apply_postprocessing:
55
+ hydra_overrides_extra = hydra_overrides_extra.copy()
56
+ hydra_overrides_extra += [
57
+ # dynamically fall back to multi-mask if the single mask is not stable
58
+ "++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true",
59
+ "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05",
60
+ "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98",
61
+ # binarize the sigmoid mask logits on interacted frames with clicks in the memory encoder so that the encoded masks are exactly as what users see from clicking
62
+ "++model.binarize_mask_from_pts_for_mem_enc=true",
63
+ # fill small holes in the low-res masks up to `fill_hole_area` (before resizing them to the original video resolution)
64
+ "++model.fill_hole_area=8",
65
+ ]
66
+ hydra_overrides.extend(hydra_overrides_extra)
67
+
68
+ # Read config and init model
69
+ cfg = compose(config_name=config_file, overrides=hydra_overrides)
70
+ OmegaConf.resolve(cfg)
71
+ model = instantiate(cfg.model, _recursive_=True)
72
+ _load_checkpoint(model, ckpt_path)
73
+ model = model.to(device)
74
+ if mode == "eval":
75
+ model.eval()
76
+ return model
77
+
78
+
79
+ def _load_checkpoint(model, ckpt_path):
80
+ if ckpt_path is not None:
81
+ sd = torch.load(ckpt_path, map_location="cpu")["model"]
82
+ missing_keys, unexpected_keys = model.load_state_dict(sd)
83
+ if missing_keys:
84
+ logging.error(missing_keys)
85
+ raise RuntimeError()
86
+ if unexpected_keys:
87
+ logging.error(unexpected_keys)
88
+ raise RuntimeError()
89
+ logging.info("Loaded checkpoint sucessfully")
third_parts/sam2/csrc/connected_components.cu ADDED
@@ -0,0 +1,289 @@
1
+ // Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ // All rights reserved.
3
+
4
+ // This source code is licensed under the license found in the
5
+ // LICENSE file in the root directory of this source tree.
6
+
7
+ // adapted from https://github.com/zsef123/Connected_components_PyTorch
8
+ // with license found in the LICENSE_cctorch file in the root directory.
9
+ #include <ATen/cuda/CUDAContext.h>
10
+ #include <cuda.h>
11
+ #include <cuda_runtime.h>
12
+ #include <torch/extension.h>
13
+ #include <torch/script.h>
14
+ #include <vector>
15
+
16
+ // 2d
17
+ #define BLOCK_ROWS 16
18
+ #define BLOCK_COLS 16
19
+
20
+ namespace cc2d {
21
+
22
+ template <typename T>
23
+ __device__ __forceinline__ unsigned char hasBit(T bitmap, unsigned char pos) {
24
+ return (bitmap >> pos) & 1;
25
+ }
26
+
27
+ __device__ int32_t find(const int32_t* s_buf, int32_t n) {
28
+ while (s_buf[n] != n)
29
+ n = s_buf[n];
30
+ return n;
31
+ }
32
+
33
+ __device__ int32_t find_n_compress(int32_t* s_buf, int32_t n) {
34
+ const int32_t id = n;
35
+ while (s_buf[n] != n) {
36
+ n = s_buf[n];
37
+ s_buf[id] = n;
38
+ }
39
+ return n;
40
+ }
41
+
42
+ __device__ void union_(int32_t* s_buf, int32_t a, int32_t b) {
43
+ bool done;
44
+ do {
45
+ a = find(s_buf, a);
46
+ b = find(s_buf, b);
47
+
48
+ if (a < b) {
49
+ int32_t old = atomicMin(s_buf + b, a);
50
+ done = (old == b);
51
+ b = old;
52
+ } else if (b < a) {
53
+ int32_t old = atomicMin(s_buf + a, b);
54
+ done = (old == a);
55
+ a = old;
56
+ } else
57
+ done = true;
58
+
59
+ } while (!done);
60
+ }
61
+
62
+ __global__ void
63
+ init_labeling(int32_t* label, const uint32_t W, const uint32_t H) {
64
+ const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y) * 2;
65
+ const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x) * 2;
66
+ const uint32_t idx = row * W + col;
67
+
68
+ if (row < H && col < W)
69
+ label[idx] = idx;
70
+ }
71
+
72
+ __global__ void
73
+ merge(uint8_t* img, int32_t* label, const uint32_t W, const uint32_t H) {
74
+ const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y) * 2;
75
+ const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x) * 2;
76
+ const uint32_t idx = row * W + col;
77
+
78
+ if (row >= H || col >= W)
79
+ return;
80
+
81
+ uint32_t P = 0;
82
+
83
+ if (img[idx])
84
+ P |= 0x777;
85
+ if (row + 1 < H && img[idx + W])
86
+ P |= 0x777 << 4;
87
+ if (col + 1 < W && img[idx + 1])
88
+ P |= 0x777 << 1;
89
+
90
+ if (col == 0)
91
+ P &= 0xEEEE;
92
+ if (col + 1 >= W)
93
+ P &= 0x3333;
94
+ else if (col + 2 >= W)
95
+ P &= 0x7777;
96
+
97
+ if (row == 0)
98
+ P &= 0xFFF0;
99
+ if (row + 1 >= H)
100
+ P &= 0xFF;
101
+
102
+ if (P > 0) {
103
+ // If bit 0 of P is set, the top-left neighbor needs checking: when that
104
+ // pixel is foreground, union with the top-left block
105
+ if (hasBit(P, 0) && img[idx - W - 1]) {
106
+ union_(label, idx, idx - 2 * W - 2); // top left block
107
+ }
108
+
109
+ if ((hasBit(P, 1) && img[idx - W]) || (hasBit(P, 2) && img[idx - W + 1]))
110
+ union_(label, idx, idx - 2 * W); // top bottom block
111
+
112
+ if (hasBit(P, 3) && img[idx + 2 - W])
113
+ union_(label, idx, idx - 2 * W + 2); // top right block
114
+
115
+ if ((hasBit(P, 4) && img[idx - 1]) || (hasBit(P, 8) && img[idx + W - 1]))
116
+ union_(label, idx, idx - 2); // just left block
117
+ }
118
+ }
119
+
120
+ __global__ void compression(int32_t* label, const int32_t W, const int32_t H) {
121
+ const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y) * 2;
122
+ const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x) * 2;
123
+ const uint32_t idx = row * W + col;
124
+
125
+ if (row < H && col < W)
126
+ find_n_compress(label, idx);
127
+ }
128
+
129
+ __global__ void final_labeling(
130
+ const uint8_t* img,
131
+ int32_t* label,
132
+ const int32_t W,
133
+ const int32_t H) {
134
+ const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y) * 2;
135
+ const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x) * 2;
136
+ const uint32_t idx = row * W + col;
137
+
138
+ if (row >= H || col >= W)
139
+ return;
140
+
141
+ int32_t y = label[idx] + 1;
142
+
143
+ if (img[idx])
144
+ label[idx] = y;
145
+ else
146
+ label[idx] = 0;
147
+
148
+ if (col + 1 < W) {
149
+ if (img[idx + 1])
150
+ label[idx + 1] = y;
151
+ else
152
+ label[idx + 1] = 0;
153
+
154
+ if (row + 1 < H) {
155
+ if (img[idx + W + 1])
156
+ label[idx + W + 1] = y;
157
+ else
158
+ label[idx + W + 1] = 0;
159
+ }
160
+ }
161
+
162
+ if (row + 1 < H) {
163
+ if (img[idx + W])
164
+ label[idx + W] = y;
165
+ else
166
+ label[idx + W] = 0;
167
+ }
168
+ }
169
+
170
+ __global__ void init_counting(
171
+ const int32_t* label,
172
+ int32_t* count_init,
173
+ const int32_t W,
174
+ const int32_t H) {
175
+ const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y);
176
+ const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x);
177
+ const uint32_t idx = row * W + col;
178
+
179
+ if (row >= H || col >= W)
180
+ return;
181
+
182
+ int32_t y = label[idx];
183
+ if (y > 0) {
184
+ int32_t count_idx = y - 1;
185
+ atomicAdd(count_init + count_idx, 1);
186
+ }
187
+ }
188
+
189
+ __global__ void final_counting(
190
+ const int32_t* label,
191
+ const int32_t* count_init,
192
+ int32_t* count_final,
193
+ const int32_t W,
194
+ const int32_t H) {
195
+ const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y);
196
+ const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x);
197
+ const uint32_t idx = row * W + col;
198
+
199
+ if (row >= H || col >= W)
200
+ return;
201
+
202
+ int32_t y = label[idx];
203
+ if (y > 0) {
204
+ int32_t count_idx = y - 1;
205
+ count_final[idx] = count_init[count_idx];
206
+ } else {
207
+ count_final[idx] = 0;
208
+ }
209
+ }
210
+
211
+ } // namespace cc2d
212
+
213
+ std::vector<torch::Tensor> get_connected_componnets(
214
+ const torch::Tensor& inputs) {
215
+ AT_ASSERTM(inputs.is_cuda(), "inputs must be a CUDA tensor");
216
+ AT_ASSERTM(inputs.ndimension() == 4, "inputs must be [N, 1, H, W] shape");
217
+ AT_ASSERTM(
218
+ inputs.scalar_type() == torch::kUInt8, "inputs must be a uint8 type");
219
+
220
+ const uint32_t N = inputs.size(0);
221
+ const uint32_t C = inputs.size(1);
222
+ const uint32_t H = inputs.size(2);
223
+ const uint32_t W = inputs.size(3);
224
+
225
+ AT_ASSERTM(C == 1, "inputs must be [N, 1, H, W] shape");
226
+ AT_ASSERTM((H % 2) == 0, "height must be an even number");
227
+ AT_ASSERTM((W % 2) == 0, "width must be an even number");
228
+
229
+ // label must be uint32_t
230
+ auto label_options =
231
+ torch::TensorOptions().dtype(torch::kInt32).device(inputs.device());
232
+ torch::Tensor labels = torch::zeros({N, C, H, W}, label_options);
233
+ torch::Tensor counts_init = torch::zeros({N, C, H, W}, label_options);
234
+ torch::Tensor counts_final = torch::zeros({N, C, H, W}, label_options);
235
+
236
+ dim3 grid = dim3(
237
+ ((W + 1) / 2 + BLOCK_COLS - 1) / BLOCK_COLS,
238
+ ((H + 1) / 2 + BLOCK_ROWS - 1) / BLOCK_ROWS);
239
+ dim3 block = dim3(BLOCK_COLS, BLOCK_ROWS);
240
+ dim3 grid_count =
241
+ dim3((W + BLOCK_COLS) / BLOCK_COLS, (H + BLOCK_ROWS) / BLOCK_ROWS);
242
+ dim3 block_count = dim3(BLOCK_COLS, BLOCK_ROWS);
243
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream();
244
+
245
+ for (int n = 0; n < N; n++) {
246
+ uint32_t offset = n * H * W;
247
+
248
+ cc2d::init_labeling<<<grid, block, 0, stream>>>(
249
+ labels.data_ptr<int32_t>() + offset, W, H);
250
+ cc2d::merge<<<grid, block, 0, stream>>>(
251
+ inputs.data_ptr<uint8_t>() + offset,
252
+ labels.data_ptr<int32_t>() + offset,
253
+ W,
254
+ H);
255
+ cc2d::compression<<<grid, block, 0, stream>>>(
256
+ labels.data_ptr<int32_t>() + offset, W, H);
257
+ cc2d::final_labeling<<<grid, block, 0, stream>>>(
258
+ inputs.data_ptr<uint8_t>() + offset,
259
+ labels.data_ptr<int32_t>() + offset,
260
+ W,
261
+ H);
262
+
263
+ // get the counting of each pixel
264
+ cc2d::init_counting<<<grid_count, block_count, 0, stream>>>(
265
+ labels.data_ptr<int32_t>() + offset,
266
+ counts_init.data_ptr<int32_t>() + offset,
267
+ W,
268
+ H);
269
+ cc2d::final_counting<<<grid_count, block_count, 0, stream>>>(
270
+ labels.data_ptr<int32_t>() + offset,
271
+ counts_init.data_ptr<int32_t>() + offset,
272
+ counts_final.data_ptr<int32_t>() + offset,
273
+ W,
274
+ H);
275
+ }
276
+
277
+ // returned values are [labels, counts]
278
+ std::vector<torch::Tensor> outputs;
279
+ outputs.push_back(labels);
280
+ outputs.push_back(counts_final);
281
+ return outputs;
282
+ }
283
+
284
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
285
+ m.def(
286
+ "get_connected_componnets",
287
+ &get_connected_componnets,
288
+ "get_connected_componnets");
289
+ }
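A sketch of building and calling this kernel as an ad-hoc PyTorch extension; this assumes a local CUDA toolchain, and the repository may instead compile it through its own build script:

import torch
from torch.utils.cpp_extension import load

cc2d = load(
    name="sam2_connected_components",
    sources=["third_parts/sam2/csrc/connected_components.cu"],
    verbose=True,
)

# Per the AT_ASSERTM checks above: a CUDA uint8 tensor of shape [N, 1, H, W] with even H and W.
masks = (torch.rand(2, 1, 256, 256, device="cuda") > 0.5).to(torch.uint8)
labels, counts = cc2d.get_connected_componnets(masks)  # note: upstream keeps this spelling
print(labels.shape, counts.shape)  # both [2, 1, 256, 256], int32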
third_parts/sam2/modeling/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
third_parts/sam2/modeling/backbones/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
third_parts/sam2/modeling/backbones/hieradet.py ADDED
@@ -0,0 +1,295 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from functools import partial
8
+ from typing import List, Tuple, Union
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+
14
+ from third_parts.sam2.modeling.backbones.utils import (
15
+ PatchEmbed,
16
+ window_partition,
17
+ window_unpartition,
18
+ )
19
+
20
+ from third_parts.sam2.modeling.sam2_utils import DropPath, MLP
21
+
22
+
23
+ def do_pool(x: torch.Tensor, pool: nn.Module, norm: nn.Module = None) -> torch.Tensor:
24
+ if pool is None:
25
+ return x
26
+ # (B, H, W, C) -> (B, C, H, W)
27
+ x = x.permute(0, 3, 1, 2)
28
+ x = pool(x)
29
+ # (B, C, H', W') -> (B, H', W', C)
30
+ x = x.permute(0, 2, 3, 1)
31
+ if norm:
32
+ x = norm(x)
33
+
34
+ return x
35
+
36
+
37
+ class MultiScaleAttention(nn.Module):
38
+ def __init__(
39
+ self,
40
+ dim: int,
41
+ dim_out: int,
42
+ num_heads: int,
43
+ q_pool: nn.Module = None,
44
+ ):
45
+ super().__init__()
46
+
47
+ self.dim = dim
48
+ self.dim_out = dim_out
49
+
50
+ self.num_heads = num_heads
51
+ head_dim = dim_out // num_heads
52
+ self.scale = head_dim**-0.5
53
+
54
+ self.q_pool = q_pool
55
+ self.qkv = nn.Linear(dim, dim_out * 3)
56
+ self.proj = nn.Linear(dim_out, dim_out)
57
+
58
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
59
+ B, H, W, _ = x.shape
60
+ # qkv with shape (B, H * W, 3, nHead, C)
61
+ qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1)
62
+ # q, k, v with shape (B, H * W, nheads, C)
63
+ q, k, v = torch.unbind(qkv, 2)
64
+
65
+ # Q pooling (for downsample at stage changes)
66
+ if self.q_pool:
67
+ q = do_pool(q.reshape(B, H, W, -1), self.q_pool)
68
+ H, W = q.shape[1:3] # downsampled shape
69
+ q = q.reshape(B, H * W, self.num_heads, -1)
70
+
71
+ # Torch's SDPA expects [B, nheads, H*W, C] so we transpose
72
+ x = F.scaled_dot_product_attention(
73
+ q.transpose(1, 2),
74
+ k.transpose(1, 2),
75
+ v.transpose(1, 2),
76
+ )
77
+ # Transpose back
78
+ x = x.transpose(1, 2)
79
+ x = x.reshape(B, H, W, -1)
80
+
81
+ x = self.proj(x)
82
+
83
+ return x
84
+
85
+
86
+ class MultiScaleBlock(nn.Module):
87
+ def __init__(
88
+ self,
89
+ dim: int,
90
+ dim_out: int,
91
+ num_heads: int,
92
+ mlp_ratio: float = 4.0,
93
+ drop_path: float = 0.0,
94
+ norm_layer: Union[nn.Module, str] = "LayerNorm",
95
+ q_stride: Tuple[int, int] = None,
96
+ act_layer: nn.Module = nn.GELU,
97
+ window_size: int = 0,
98
+ ):
99
+ super().__init__()
100
+
101
+ if isinstance(norm_layer, str):
102
+ norm_layer = partial(getattr(nn, norm_layer), eps=1e-6)
103
+
104
+ self.dim = dim
105
+ self.dim_out = dim_out
106
+ self.norm1 = norm_layer(dim)
107
+
108
+ self.window_size = window_size
109
+
110
+ self.pool, self.q_stride = None, q_stride
111
+ if self.q_stride:
112
+ self.pool = nn.MaxPool2d(
113
+ kernel_size=q_stride, stride=q_stride, ceil_mode=False
114
+ )
115
+
116
+ self.attn = MultiScaleAttention(
117
+ dim,
118
+ dim_out,
119
+ num_heads=num_heads,
120
+ q_pool=self.pool,
121
+ )
122
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
123
+
124
+ self.norm2 = norm_layer(dim_out)
125
+ self.mlp = MLP(
126
+ dim_out,
127
+ int(dim_out * mlp_ratio),
128
+ dim_out,
129
+ num_layers=2,
130
+ activation=act_layer,
131
+ )
132
+
133
+ if dim != dim_out:
134
+ self.proj = nn.Linear(dim, dim_out)
135
+
136
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
137
+ shortcut = x # B, H, W, C
138
+ x = self.norm1(x)
139
+
140
+ # Skip connection
141
+ if self.dim != self.dim_out:
142
+ shortcut = do_pool(self.proj(x), self.pool)
143
+
144
+ # Window partition
145
+ window_size = self.window_size
146
+ if window_size > 0:
147
+ H, W = x.shape[1], x.shape[2]
148
+ x, pad_hw = window_partition(x, window_size)
149
+
150
+ # Window Attention + Q Pooling (if stage change)
151
+ x = self.attn(x)
152
+ if self.q_stride:
153
+ # Shapes have changed due to Q pooling
154
+ window_size = self.window_size // self.q_stride[0]
155
+ H, W = shortcut.shape[1:3]
156
+
157
+ pad_h = (window_size - H % window_size) % window_size
158
+ pad_w = (window_size - W % window_size) % window_size
159
+ pad_hw = (H + pad_h, W + pad_w)
160
+
161
+ # Reverse window partition
162
+ if self.window_size > 0:
163
+ x = window_unpartition(x, window_size, pad_hw, (H, W))
164
+
165
+ x = shortcut + self.drop_path(x)
166
+ # MLP
167
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
168
+ return x
169
+
170
+
171
+ class Hiera(nn.Module):
172
+ """
173
+ Reference: https://arxiv.org/abs/2306.00989
174
+ """
175
+
176
+ def __init__(
177
+ self,
178
+ embed_dim: int = 96, # initial embed dim
179
+ num_heads: int = 1, # initial number of heads
180
+ drop_path_rate: float = 0.0, # stochastic depth
181
+ q_pool: int = 3, # number of q_pool stages
182
+ q_stride: Tuple[int, int] = (2, 2), # downsample stride bet. stages
183
+ stages: Tuple[int, ...] = (2, 3, 16, 3), # blocks per stage
184
+ dim_mul: float = 2.0, # dim_mul factor at stage shift
185
+ head_mul: float = 2.0, # head_mul factor at stage shift
186
+ window_pos_embed_bkg_spatial_size: Tuple[int, int] = (14, 14),
187
+ # window size per stage, when not using global att.
188
+ window_spec: Tuple[int, ...] = (
189
+ 8,
190
+ 4,
191
+ 14,
192
+ 7,
193
+ ),
194
+ # global attn in these blocks
195
+ global_att_blocks: Tuple[int, ...] = (
196
+ 12,
197
+ 16,
198
+ 20,
199
+ ),
200
+ return_interm_layers=True, # return feats from every stage
201
+ ):
202
+ super().__init__()
203
+
204
+ assert len(stages) == len(window_spec)
205
+ self.window_spec = window_spec
206
+
207
+ depth = sum(stages)
208
+ self.q_stride = q_stride
209
+ self.stage_ends = [sum(stages[:i]) - 1 for i in range(1, len(stages) + 1)]
210
+ assert 0 <= q_pool <= len(self.stage_ends[:-1])
211
+ self.q_pool_blocks = [x + 1 for x in self.stage_ends[:-1]][:q_pool]
212
+ self.return_interm_layers = return_interm_layers
213
+
214
+ self.patch_embed = PatchEmbed(
215
+ embed_dim=embed_dim,
216
+ )
217
+ # Which blocks have global att?
218
+ self.global_att_blocks = global_att_blocks
219
+
220
+ # Windowed positional embedding (https://arxiv.org/abs/2311.05613)
221
+ self.window_pos_embed_bkg_spatial_size = window_pos_embed_bkg_spatial_size
222
+ self.pos_embed = nn.Parameter(
223
+ torch.zeros(1, embed_dim, *self.window_pos_embed_bkg_spatial_size)
224
+ )
225
+ self.pos_embed_window = nn.Parameter(
226
+ torch.zeros(1, embed_dim, self.window_spec[0], self.window_spec[0])
227
+ )
228
+
229
+ dpr = [
230
+ x.item() for x in torch.linspace(0, drop_path_rate, depth)
231
+ ] # stochastic depth decay rule
232
+
233
+ cur_stage = 1
234
+ self.blocks = nn.ModuleList()
235
+
236
+ for i in range(depth):
237
+ dim_out = embed_dim
238
+ # lags by a block, so first block of
239
+ # next stage uses an initial window size
240
+ # of previous stage and final window size of current stage
241
+ window_size = self.window_spec[cur_stage - 1]
242
+
243
+ if self.global_att_blocks is not None:
244
+ window_size = 0 if i in self.global_att_blocks else window_size
245
+
246
+ if i - 1 in self.stage_ends:
247
+ dim_out = int(embed_dim * dim_mul)
248
+ num_heads = int(num_heads * head_mul)
249
+ cur_stage += 1
250
+
251
+ block = MultiScaleBlock(
252
+ dim=embed_dim,
253
+ dim_out=dim_out,
254
+ num_heads=num_heads,
255
+ drop_path=dpr[i],
256
+ q_stride=self.q_stride if i in self.q_pool_blocks else None,
257
+ window_size=window_size,
258
+ )
259
+
260
+ embed_dim = dim_out
261
+ self.blocks.append(block)
262
+
263
+ self.channel_list = (
264
+ [self.blocks[i].dim_out for i in self.stage_ends[::-1]]
265
+ if return_interm_layers
266
+ else [self.blocks[-1].dim_out]
267
+ )
268
+
269
+ def _get_pos_embed(self, hw: Tuple[int, int]) -> torch.Tensor:
270
+ h, w = hw
271
+ window_embed = self.pos_embed_window
272
+ pos_embed = F.interpolate(self.pos_embed, size=(h, w), mode="bicubic")
273
+ pos_embed = pos_embed + window_embed.tile(
274
+ [x // y for x, y in zip(pos_embed.shape, window_embed.shape)]
275
+ )
276
+ pos_embed = pos_embed.permute(0, 2, 3, 1)
277
+ return pos_embed
278
+
279
+ def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
280
+ x = self.patch_embed(x)
281
+ # x: (B, H, W, C)
282
+
283
+ # Add pos embed
284
+ x = x + self._get_pos_embed(x.shape[1:3])
285
+
286
+ outputs = []
287
+ for i, blk in enumerate(self.blocks):
288
+ x = blk(x)
289
+ if (i == self.stage_ends[-1]) or (
290
+ i in self.stage_ends and self.return_interm_layers
291
+ ):
292
+ feats = x.permute(0, 3, 1, 2)
293
+ outputs.append(feats)
294
+
295
+ return outputs
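A quick shape check for the trunk defined above, using its default configuration (embed_dim=96, stages=(2, 3, 16, 3)); runnable on CPU:

import torch
from third_parts.sam2.modeling.backbones.hieradet import Hiera

trunk = Hiera()
print(trunk.channel_list)  # [768, 384, 192, 96], highest stride first

x = torch.randn(1, 3, 256, 256)  # patch embedding has stride 4 -> 64x64 tokens
for feat in trunk(x):
    print(feat.shape)
# (1, 96, 64, 64), (1, 192, 32, 32), (1, 384, 16, 16), (1, 768, 8, 8)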
third_parts/sam2/modeling/backbones/image_encoder.py ADDED
@@ -0,0 +1,133 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from typing import List, Optional
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+
13
+
14
+ class ImageEncoder(nn.Module):
15
+ def __init__(
16
+ self,
17
+ trunk: nn.Module,
18
+ neck: nn.Module,
19
+ scalp: int = 0,
20
+ ):
21
+ super().__init__()
22
+ self.trunk = trunk
23
+ self.neck = neck
24
+ self.scalp = scalp
25
+ assert (
26
+ self.trunk.channel_list == self.neck.backbone_channel_list
27
+ ), f"Channel dims of trunk and neck do not match. Trunk: {self.trunk.channel_list}, neck: {self.neck.backbone_channel_list}"
28
+
29
+ def forward(self, sample: torch.Tensor):
30
+ # Forward through backbone
31
+ features, pos = self.neck(self.trunk(sample))
32
+ if self.scalp > 0:
33
+ # Discard the lowest resolution features
34
+ features, pos = features[: -self.scalp], pos[: -self.scalp]
35
+
36
+ src = features[-1]
37
+ output = {
38
+ "vision_features": src,
39
+ "vision_pos_enc": pos,
40
+ "backbone_fpn": features,
41
+ }
42
+ return output
43
+
44
+
45
+ class FpnNeck(nn.Module):
46
+ """
47
+ A modified variant of Feature Pyramid Network (FPN) neck
48
+ (we remove output conv and also do bicubic interpolation similar to ViT
49
+ pos embed interpolation)
50
+ """
51
+
52
+ def __init__(
53
+ self,
54
+ position_encoding: nn.Module,
55
+ d_model: int,
56
+ backbone_channel_list: List[int],
57
+ kernel_size: int = 1,
58
+ stride: int = 1,
59
+ padding: int = 0,
60
+ fpn_interp_model: str = "bilinear",
61
+ fuse_type: str = "sum",
62
+ fpn_top_down_levels: Optional[List[int]] = None,
63
+ ):
64
+ """Initialize the neck
65
+ :param backbone_channel_list: channel dims of the trunk feature maps, highest stride first
66
+ :param position_encoding: the positional encoding to use
67
+ :param d_model: the dimension of the model
68
+ :param fuse_type: how lateral and top-down features are combined ("sum" or "avg")
69
+ """
70
+ super().__init__()
71
+ self.position_encoding = position_encoding
72
+ self.convs = nn.ModuleList()
73
+ self.backbone_channel_list = backbone_channel_list
74
+ for dim in backbone_channel_list:
75
+ current = nn.Sequential()
76
+ current.add_module(
77
+ "conv",
78
+ nn.Conv2d(
79
+ in_channels=dim,
80
+ out_channels=d_model,
81
+ kernel_size=kernel_size,
82
+ stride=stride,
83
+ padding=padding,
84
+ ),
85
+ )
86
+
87
+ self.convs.append(current)
88
+ self.fpn_interp_model = fpn_interp_model
89
+ assert fuse_type in ["sum", "avg"]
90
+ self.fuse_type = fuse_type
91
+
92
+ # levels to have top-down features in its outputs
93
+ # e.g. if fpn_top_down_levels is [2, 3], then only outputs of level 2 and 3
94
+ # have top-down propagation, while outputs of level 0 and level 1 have only
95
+ # lateral features from the same backbone level.
96
+ if fpn_top_down_levels is None:
97
+ # default is to have top-down features on all levels
98
+ fpn_top_down_levels = range(len(self.convs))
99
+ self.fpn_top_down_levels = list(fpn_top_down_levels)
100
+
101
+ def forward(self, xs: List[torch.Tensor]):
102
+
103
+ out = [None] * len(self.convs)
104
+ pos = [None] * len(self.convs)
105
+ assert len(xs) == len(self.convs)
106
+ # fpn forward pass
107
+ # see https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/fpn.py
108
+ prev_features = None
109
+ # forward in top-down order (from low to high resolution)
110
+ n = len(self.convs) - 1
111
+ for i in range(n, -1, -1):
112
+ x = xs[i]
113
+ lateral_features = self.convs[n - i](x)
114
+ if i in self.fpn_top_down_levels and prev_features is not None:
115
+ top_down_features = F.interpolate(
116
+ prev_features.to(dtype=torch.float32),
117
+ scale_factor=2.0,
118
+ mode=self.fpn_interp_model,
119
+ align_corners=(
120
+ None if self.fpn_interp_model == "nearest" else False
121
+ ),
122
+ antialias=False,
123
+ )
124
+ prev_features = lateral_features + top_down_features
125
+ if self.fuse_type == "avg":
126
+ prev_features /= 2
127
+ else:
128
+ prev_features = lateral_features
129
+ x_out = prev_features
130
+ out[i] = x_out
131
+ pos[i] = self.position_encoding(x_out).to(x_out.dtype)
132
+
133
+ return out, pos
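A minimal sketch wiring the pieces of this file together by hand (the YAML files under sam2_configs do the same via Hydra); the channel list must match Hiera's stage dims in highest-stride-first order:

import torch
from third_parts.sam2.modeling.backbones.hieradet import Hiera
from third_parts.sam2.modeling.backbones.image_encoder import ImageEncoder, FpnNeck
from third_parts.sam2.modeling.position_encoding import PositionEmbeddingSine

trunk = Hiera()  # channel_list == [768, 384, 192, 96]
neck = FpnNeck(
    position_encoding=PositionEmbeddingSine(num_pos_feats=256),
    d_model=256,
    backbone_channel_list=[768, 384, 192, 96],
    fpn_top_down_levels=[2, 3],  # top-down fusion only on the two coarsest levels
)
encoder = ImageEncoder(trunk=trunk, neck=neck, scalp=1)  # scalp=1 discards the lowest-resolution level

out = encoder(torch.randn(1, 3, 256, 256))
print(out["vision_features"].shape)            # (1, 256, 16, 16): the stride-16 feature map
print([f.shape for f in out["backbone_fpn"]])  # remaining multi-scale features, 256 channels each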
third_parts/sam2/modeling/backbones/utils.py ADDED
@@ -0,0 +1,95 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """Some utilities for backbones, in particular for windowing"""
8
+
9
+ from typing import Tuple
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+
15
+
16
+ def window_partition(x, window_size):
17
+ """
18
+ Partition into non-overlapping windows with padding if needed.
19
+ Args:
20
+ x (tensor): input tokens with [B, H, W, C].
21
+ window_size (int): window size.
22
+ Returns:
23
+ windows: windows after partition with [B * num_windows, window_size, window_size, C].
24
+ (Hp, Wp): padded height and width before partition
25
+ """
26
+ B, H, W, C = x.shape
27
+
28
+ pad_h = (window_size - H % window_size) % window_size
29
+ pad_w = (window_size - W % window_size) % window_size
30
+ if pad_h > 0 or pad_w > 0:
31
+ x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
32
+ Hp, Wp = H + pad_h, W + pad_w
33
+
34
+ x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
35
+ windows = (
36
+ x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
37
+ )
38
+ return windows, (Hp, Wp)
39
+
40
+
41
+ def window_unpartition(windows, window_size, pad_hw, hw):
42
+ """
43
+ Window unpartition into original sequences and removing padding.
44
+ Args:
45
+ x (tensor): input tokens with [B * num_windows, window_size, window_size, C].
46
+ window_size (int): window size.
47
+ pad_hw (Tuple): padded height and width (Hp, Wp).
48
+ hw (Tuple): original height and width (H, W) before padding.
49
+ Returns:
50
+ x: unpartitioned sequences with [B, H, W, C].
51
+ """
52
+ Hp, Wp = pad_hw
53
+ H, W = hw
54
+ B = windows.shape[0] // (Hp * Wp // window_size // window_size)
55
+ x = windows.view(
56
+ B, Hp // window_size, Wp // window_size, window_size, window_size, -1
57
+ )
58
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)
59
+
60
+ if Hp > H or Wp > W:
61
+ x = x[:, :H, :W, :].contiguous()
62
+ return x
63
+
64
+
65
+ class PatchEmbed(nn.Module):
66
+ """
67
+ Image to Patch Embedding.
68
+ """
69
+
70
+ def __init__(
71
+ self,
72
+ kernel_size: Tuple[int, ...] = (7, 7),
73
+ stride: Tuple[int, ...] = (4, 4),
74
+ padding: Tuple[int, ...] = (3, 3),
75
+ in_chans: int = 3,
76
+ embed_dim: int = 768,
77
+ ):
78
+ """
79
+ Args:
80
+ kernel_size (Tuple): kernel size of the projection layer.
81
+ stride (Tuple): stride of the projection layer.
82
+ padding (Tuple): padding size of the projection layer.
83
+ in_chans (int): Number of input image channels.
84
+ embed_dim (int): embed_dim (int): Patch embedding dimension.
85
+ """
86
+ super().__init__()
87
+ self.proj = nn.Conv2d(
88
+ in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding
89
+ )
90
+
91
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
92
+ x = self.proj(x)
93
+ # B C H W -> B H W C
94
+ x = x.permute(0, 2, 3, 1)
95
+ return x
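The two windowing helpers are exact inverses once the padding is stripped; a small round-trip check:

import torch
from third_parts.sam2.modeling.backbones.utils import window_partition, window_unpartition

x = torch.randn(2, 30, 45, 64)            # (B, H, W, C), H and W not multiples of the window size
windows, pad_hw = window_partition(x, 8)  # pads to (32, 48), then splits into 8x8 windows
print(windows.shape, pad_hw)              # torch.Size([48, 8, 8, 64]) (32, 48)

y = window_unpartition(windows, 8, pad_hw, (30, 45))
print(torch.equal(x, y))                  # True: the padding is cropped away again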
third_parts/sam2/modeling/memory_attention.py ADDED
@@ -0,0 +1,169 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from typing import Optional
8
+
9
+ import torch
10
+ from torch import nn, Tensor
11
+
12
+ from third_parts.sam2.modeling.sam.transformer import RoPEAttention
13
+
14
+ from third_parts.sam2.modeling.sam2_utils import get_activation_fn, get_clones
15
+
16
+
17
+ class MemoryAttentionLayer(nn.Module):
18
+
19
+ def __init__(
20
+ self,
21
+ activation: str,
22
+ cross_attention: nn.Module,
23
+ d_model: int,
24
+ dim_feedforward: int,
25
+ dropout: float,
26
+ pos_enc_at_attn: bool,
27
+ pos_enc_at_cross_attn_keys: bool,
28
+ pos_enc_at_cross_attn_queries: bool,
29
+ self_attention: nn.Module,
30
+ ):
31
+ super().__init__()
32
+ self.d_model = d_model
33
+ self.dim_feedforward = dim_feedforward
34
+ self.dropout_value = dropout
35
+ self.self_attn = self_attention
36
+ self.cross_attn_image = cross_attention
37
+
38
+ # Implementation of Feedforward model
39
+ self.linear1 = nn.Linear(d_model, dim_feedforward)
40
+ self.dropout = nn.Dropout(dropout)
41
+ self.linear2 = nn.Linear(dim_feedforward, d_model)
42
+
43
+ self.norm1 = nn.LayerNorm(d_model)
44
+ self.norm2 = nn.LayerNorm(d_model)
45
+ self.norm3 = nn.LayerNorm(d_model)
46
+ self.dropout1 = nn.Dropout(dropout)
47
+ self.dropout2 = nn.Dropout(dropout)
48
+ self.dropout3 = nn.Dropout(dropout)
49
+
50
+ self.activation_str = activation
51
+ self.activation = get_activation_fn(activation)
52
+
53
+ # Where to add pos enc
54
+ self.pos_enc_at_attn = pos_enc_at_attn
55
+ self.pos_enc_at_cross_attn_queries = pos_enc_at_cross_attn_queries
56
+ self.pos_enc_at_cross_attn_keys = pos_enc_at_cross_attn_keys
57
+
58
+ def _forward_sa(self, tgt, query_pos):
59
+ # Self-Attention
60
+ tgt2 = self.norm1(tgt)
61
+ q = k = tgt2 + query_pos if self.pos_enc_at_attn else tgt2
62
+ tgt2 = self.self_attn(q, k, v=tgt2)
63
+ tgt = tgt + self.dropout1(tgt2)
64
+ return tgt
65
+
66
+ def _forward_ca(self, tgt, memory, query_pos, pos, num_k_exclude_rope=0):
67
+ kwds = {}
68
+ if num_k_exclude_rope > 0:
69
+ assert isinstance(self.cross_attn_image, RoPEAttention)
70
+ kwds = {"num_k_exclude_rope": num_k_exclude_rope}
71
+
72
+ # Cross-Attention
73
+ tgt2 = self.norm2(tgt)
74
+ tgt2 = self.cross_attn_image(
75
+ q=tgt2 + query_pos if self.pos_enc_at_cross_attn_queries else tgt2,
76
+ k=memory + pos if self.pos_enc_at_cross_attn_keys else memory,
77
+ v=memory,
78
+ **kwds,
79
+ )
80
+ tgt = tgt + self.dropout2(tgt2)
81
+ return tgt
82
+
83
+ def forward(
84
+ self,
85
+ tgt,
86
+ memory,
87
+ pos: Optional[Tensor] = None,
88
+ query_pos: Optional[Tensor] = None,
89
+ num_k_exclude_rope: int = 0,
90
+ ) -> torch.Tensor:
91
+
92
+ # Self-Attn, Cross-Attn
93
+ tgt = self._forward_sa(tgt, query_pos)
94
+ tgt = self._forward_ca(tgt, memory, query_pos, pos, num_k_exclude_rope)
95
+ # MLP
96
+ tgt2 = self.norm3(tgt)
97
+ tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
98
+ tgt = tgt + self.dropout3(tgt2)
99
+ return tgt
100
+
101
+
102
+ class MemoryAttention(nn.Module):
103
+ def __init__(
104
+ self,
105
+ d_model: int,
106
+ pos_enc_at_input: bool,
107
+ layer: nn.Module,
108
+ num_layers: int,
109
+ batch_first: bool = True, # Do layers expect batch first input?
110
+ ):
111
+ super().__init__()
112
+ self.d_model = d_model
113
+ self.layers = get_clones(layer, num_layers)
114
+ self.num_layers = num_layers
115
+ self.norm = nn.LayerNorm(d_model)
116
+ self.pos_enc_at_input = pos_enc_at_input
117
+ self.batch_first = batch_first
118
+
119
+ def forward(
120
+ self,
121
+ curr: torch.Tensor, # self-attention inputs
122
+ memory: torch.Tensor, # cross-attention inputs
123
+ curr_pos: Optional[Tensor] = None, # pos_enc for self-attention inputs
124
+ memory_pos: Optional[Tensor] = None, # pos_enc for cross-attention inputs
125
+ num_obj_ptr_tokens: int = 0, # number of object pointer *tokens*
126
+ ):
127
+ if isinstance(curr, list):
128
+ assert isinstance(curr_pos, list)
129
+ assert len(curr) == len(curr_pos) == 1
130
+ curr, curr_pos = (
131
+ curr[0],
132
+ curr_pos[0],
133
+ )
134
+
135
+ assert (
136
+ curr.shape[1] == memory.shape[1]
137
+ ), "Batch size must be the same for curr and memory"
138
+
139
+ output = curr
140
+ if self.pos_enc_at_input and curr_pos is not None:
141
+ output = output + 0.1 * curr_pos
142
+
143
+ if self.batch_first:
144
+ # Convert to batch first
145
+ output = output.transpose(0, 1)
146
+ curr_pos = curr_pos.transpose(0, 1)
147
+ memory = memory.transpose(0, 1)
148
+ memory_pos = memory_pos.transpose(0, 1)
149
+
150
+ for layer in self.layers:
151
+ kwds = {}
152
+ if isinstance(layer.cross_attn_image, RoPEAttention):
153
+ kwds = {"num_k_exclude_rope": num_obj_ptr_tokens}
154
+
155
+ output = layer(
156
+ tgt=output,
157
+ memory=memory,
158
+ pos=memory_pos,
159
+ query_pos=curr_pos,
160
+ **kwds,
161
+ )
162
+ normed_output = self.norm(output)
163
+
164
+ if self.batch_first:
165
+ # Convert back to seq first
166
+ normed_output = normed_output.transpose(0, 1)
167
+ curr_pos = curr_pos.transpose(0, 1)
168
+
169
+ return normed_output
third_parts/sam2/modeling/memory_encoder.py ADDED
@@ -0,0 +1,181 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import math
8
+ from typing import Tuple
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+
14
+ from third_parts.sam2.modeling.sam2_utils import DropPath, get_clones, LayerNorm2d
15
+
16
+
17
+ class MaskDownSampler(nn.Module):
18
+ """
19
+ Progressively downsample a mask by total_stride, each time by stride.
20
+ Note that LayerNorm is applied per *token*, like in ViT.
21
+
22
+ With each downsample (by a factor stride**2), channel capacity increases by the same factor.
23
+ In the end, we linearly project to embed_dim channels.
24
+ """
25
+
26
+ def __init__(
27
+ self,
28
+ embed_dim=256,
29
+ kernel_size=4,
30
+ stride=4,
31
+ padding=0,
32
+ total_stride=16,
33
+ activation=nn.GELU,
34
+ ):
35
+ super().__init__()
36
+ num_layers = int(math.log2(total_stride) // math.log2(stride))
37
+ assert stride**num_layers == total_stride
38
+ self.encoder = nn.Sequential()
39
+ mask_in_chans, mask_out_chans = 1, 1
40
+ for _ in range(num_layers):
41
+ mask_out_chans = mask_in_chans * (stride**2)
42
+ self.encoder.append(
43
+ nn.Conv2d(
44
+ mask_in_chans,
45
+ mask_out_chans,
46
+ kernel_size=kernel_size,
47
+ stride=stride,
48
+ padding=padding,
49
+ )
50
+ )
51
+ self.encoder.append(LayerNorm2d(mask_out_chans))
52
+ self.encoder.append(activation())
53
+ mask_in_chans = mask_out_chans
54
+
55
+ self.encoder.append(nn.Conv2d(mask_out_chans, embed_dim, kernel_size=1))
56
+
57
+ def forward(self, x):
58
+ return self.encoder(x)
59
+
60
+
61
+ # Lightly adapted from ConvNext (https://github.com/facebookresearch/ConvNeXt)
62
+ class CXBlock(nn.Module):
63
+ r"""ConvNeXt Block. There are two equivalent implementations:
64
+ (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
65
+ (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
66
+ We use (2) as we find it slightly faster in PyTorch
67
+
68
+ Args:
69
+ dim (int): Number of input channels.
70
+ drop_path (float): Stochastic depth rate. Default: 0.0
71
+ layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
72
+ """
73
+
74
+ def __init__(
75
+ self,
76
+ dim,
77
+ kernel_size=7,
78
+ padding=3,
79
+ drop_path=0.0,
80
+ layer_scale_init_value=1e-6,
81
+ use_dwconv=True,
82
+ ):
83
+ super().__init__()
84
+ self.dwconv = nn.Conv2d(
85
+ dim,
86
+ dim,
87
+ kernel_size=kernel_size,
88
+ padding=padding,
89
+ groups=dim if use_dwconv else 1,
90
+ ) # depthwise conv
91
+ self.norm = LayerNorm2d(dim, eps=1e-6)
92
+ self.pwconv1 = nn.Linear(
93
+ dim, 4 * dim
94
+ ) # pointwise/1x1 convs, implemented with linear layers
95
+ self.act = nn.GELU()
96
+ self.pwconv2 = nn.Linear(4 * dim, dim)
97
+ self.gamma = (
98
+ nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True)
99
+ if layer_scale_init_value > 0
100
+ else None
101
+ )
102
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
103
+
104
+ def forward(self, x):
105
+ input = x
106
+ x = self.dwconv(x)
107
+ x = self.norm(x)
108
+ x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)
109
+ x = self.pwconv1(x)
110
+ x = self.act(x)
111
+ x = self.pwconv2(x)
112
+ if self.gamma is not None:
113
+ x = self.gamma * x
114
+ x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W)
115
+
116
+ x = input + self.drop_path(x)
117
+ return x
118
+
119
+
120
+ class Fuser(nn.Module):
121
+ def __init__(self, layer, num_layers, dim=None, input_projection=False):
122
+ super().__init__()
123
+ self.proj = nn.Identity()
124
+ self.layers = get_clones(layer, num_layers)
125
+
126
+ if input_projection:
127
+ assert dim is not None
128
+ self.proj = nn.Conv2d(dim, dim, kernel_size=1)
129
+
130
+ def forward(self, x):
131
+ # normally x: (N, C, H, W)
132
+ x = self.proj(x)
133
+ for layer in self.layers:
134
+ x = layer(x)
135
+ return x
136
+
137
+
138
+ class MemoryEncoder(nn.Module):
139
+ def __init__(
140
+ self,
141
+ out_dim,
142
+ mask_downsampler,
143
+ fuser,
144
+ position_encoding,
145
+ in_dim=256, # in_dim of pix_feats
146
+ ):
147
+ super().__init__()
148
+
149
+ self.mask_downsampler = mask_downsampler
150
+
151
+ self.pix_feat_proj = nn.Conv2d(in_dim, in_dim, kernel_size=1)
152
+ self.fuser = fuser
153
+ self.position_encoding = position_encoding
154
+ self.out_proj = nn.Identity()
155
+ if out_dim != in_dim:
156
+ self.out_proj = nn.Conv2d(in_dim, out_dim, kernel_size=1)
157
+
158
+ def forward(
159
+ self,
160
+ pix_feat: torch.Tensor,
161
+ masks: torch.Tensor,
162
+ skip_mask_sigmoid: bool = False,
163
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
164
+ ## Process masks
165
+ # apply sigmoid so there is less domain shift from the GT masks, which are boolean
166
+ if not skip_mask_sigmoid:
167
+ masks = F.sigmoid(masks)
168
+ masks = self.mask_downsampler(masks)
169
+
170
+ ## Fuse pix_feats and downsampled masks
171
+ # in case the visual features are on CPU, cast them to CUDA
172
+ pix_feat = pix_feat.to(masks.device)
173
+
174
+ x = self.pix_feat_proj(pix_feat)
175
+ x = x + masks
176
+ x = self.fuser(x)
177
+ x = self.out_proj(x)
178
+
179
+ pos = self.position_encoding(x).to(x.dtype)
180
+
181
+ return {"vision_features": x, "vision_pos_enc": [pos]}
third_parts/sam2/modeling/position_encoding.py ADDED
@@ -0,0 +1,221 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import math
8
+ from typing import Any, Optional, Tuple
9
+
10
+ import numpy as np
11
+
12
+ import torch
13
+ from torch import nn
14
+
15
+
16
+ class PositionEmbeddingSine(nn.Module):
17
+ """
18
+ This is a more standard version of the position embedding, very similar to the one
19
+ used in the "Attention Is All You Need" paper, generalized to work on images.
20
+ """
21
+
22
+ def __init__(
23
+ self,
24
+ num_pos_feats,
25
+ temperature: int = 10000,
26
+ normalize: bool = True,
27
+ scale: Optional[float] = None,
28
+ ):
29
+ super().__init__()
30
+ assert num_pos_feats % 2 == 0, "Expecting even model width"
31
+ self.num_pos_feats = num_pos_feats // 2
32
+ self.temperature = temperature
33
+ self.normalize = normalize
34
+ if scale is not None and normalize is False:
35
+ raise ValueError("normalize should be True if scale is passed")
36
+ if scale is None:
37
+ scale = 2 * math.pi
38
+ self.scale = scale
39
+
40
+ self.cache = {}
41
+
42
+ def _encode_xy(self, x, y):
43
+ # The positions are expected to be normalized
44
+ assert len(x) == len(y) and x.ndim == y.ndim == 1
45
+ x_embed = x * self.scale
46
+ y_embed = y * self.scale
47
+
48
+ dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
49
+ dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
50
+
51
+ pos_x = x_embed[:, None] / dim_t
52
+ pos_y = y_embed[:, None] / dim_t
53
+ pos_x = torch.stack(
54
+ (pos_x[:, 0::2].sin(), pos_x[:, 1::2].cos()), dim=2
55
+ ).flatten(1)
56
+ pos_y = torch.stack(
57
+ (pos_y[:, 0::2].sin(), pos_y[:, 1::2].cos()), dim=2
58
+ ).flatten(1)
59
+ return pos_x, pos_y
60
+
61
+ @torch.no_grad()
62
+ def encode_boxes(self, x, y, w, h):
63
+ pos_x, pos_y = self._encode_xy(x, y)
64
+ pos = torch.cat((pos_y, pos_x, h[:, None], w[:, None]), dim=1)
65
+ return pos
66
+
67
+ encode = encode_boxes # Backwards compatibility
68
+
69
+ @torch.no_grad()
70
+ def encode_points(self, x, y, labels):
71
+ (bx, nx), (by, ny), (bl, nl) = x.shape, y.shape, labels.shape
72
+ assert bx == by and nx == ny and bx == bl and nx == nl
73
+ pos_x, pos_y = self._encode_xy(x.flatten(), y.flatten())
74
+ pos_x, pos_y = pos_x.reshape(bx, nx, -1), pos_y.reshape(by, ny, -1)
75
+ pos = torch.cat((pos_y, pos_x, labels[:, :, None]), dim=2)
76
+ return pos
77
+
78
+ @torch.no_grad()
79
+ def forward(self, x: torch.Tensor):
80
+ cache_key = (x.shape[-2], x.shape[-1])
81
+ if cache_key in self.cache:
82
+ return self.cache[cache_key][None].repeat(x.shape[0], 1, 1, 1)
83
+ y_embed = (
84
+ torch.arange(1, x.shape[-2] + 1, dtype=torch.float32, device=x.device)
85
+ .view(1, -1, 1)
86
+ .repeat(x.shape[0], 1, x.shape[-1])
87
+ )
88
+ x_embed = (
89
+ torch.arange(1, x.shape[-1] + 1, dtype=torch.float32, device=x.device)
90
+ .view(1, 1, -1)
91
+ .repeat(x.shape[0], x.shape[-2], 1)
92
+ )
93
+
94
+ if self.normalize:
95
+ eps = 1e-6
96
+ y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
97
+ x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
98
+
99
+ dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
100
+ dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
101
+
102
+ pos_x = x_embed[:, :, :, None] / dim_t
103
+ pos_y = y_embed[:, :, :, None] / dim_t
104
+ pos_x = torch.stack(
105
+ (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
106
+ ).flatten(3)
107
+ pos_y = torch.stack(
108
+ (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
109
+ ).flatten(3)
110
+ pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
111
+ self.cache[cache_key] = pos[0]
112
+ return pos
113
+
114
+
115
+ class PositionEmbeddingRandom(nn.Module):
116
+ """
117
+ Positional encoding using random spatial frequencies.
118
+ """
119
+
120
+ def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None:
121
+ super().__init__()
122
+ if scale is None or scale <= 0.0:
123
+ scale = 1.0
124
+ self.register_buffer(
125
+ "positional_encoding_gaussian_matrix",
126
+ scale * torch.randn((2, num_pos_feats)),
127
+ )
128
+ self.first = True
129
+
130
+ def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:
131
+ """Positionally encode points that are normalized to [0,1]."""
132
+ # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
133
+ coords = 2 * coords - 1
134
+ coords = coords.to(self.positional_encoding_gaussian_matrix.dtype)
135
+ if self.first:
136
+ self.positional_encoding_gaussian_matrix = self.positional_encoding_gaussian_matrix.to(coords.device)
137
+ self.first = False
138
+ coords = coords @ self.positional_encoding_gaussian_matrix
139
+ coords = 2 * np.pi * coords
140
+ # outputs d_1 x ... x d_n x C shape
141
+ return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)
142
+
143
+ def forward(self, size: Tuple[int, int]) -> torch.Tensor:
144
+ """Generate positional encoding for a grid of the specified size."""
145
+ h, w = size
146
+ device: Any = self.positional_encoding_gaussian_matrix.device
147
+ grid = torch.ones((h, w), device=device, dtype=torch.float32)
148
+ y_embed = grid.cumsum(dim=0) - 0.5
149
+ x_embed = grid.cumsum(dim=1) - 0.5
150
+ y_embed = y_embed / h
151
+ x_embed = x_embed / w
152
+
153
+ pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1))
154
+ return pe.permute(2, 0, 1) # C x H x W
155
+
156
+ def forward_with_coords(
157
+ self, coords_input: torch.Tensor, image_size: Tuple[int, int]
158
+ ) -> torch.Tensor:
159
+ """Positionally encode points that are not normalized to [0,1]."""
160
+ coords = coords_input.clone()
161
+ coords[:, :, 0] = coords[:, :, 0] / image_size[1]
162
+ coords[:, :, 1] = coords[:, :, 1] / image_size[0]
163
+ return self._pe_encoding(coords.to(torch.float)) # B x N x C
164
+
165
+
166
+ # Rotary Positional Encoding, adapted from:
167
+ # 1. https://github.com/meta-llama/codellama/blob/main/llama/model.py
168
+ # 2. https://github.com/naver-ai/rope-vit
169
+ # 3. https://github.com/lucidrains/rotary-embedding-torch
170
+
171
+
172
+ def init_t_xy(end_x: int, end_y: int):
173
+ t = torch.arange(end_x * end_y, dtype=torch.float32)
174
+ t_x = (t % end_x).float()
175
+ t_y = torch.div(t, end_x, rounding_mode="floor").float()
176
+ return t_x, t_y
177
+
178
+
179
+ def compute_axial_cis(dim: int, end_x: int, end_y: int, theta: float = 10000.0):
180
+ freqs_x = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim))
181
+ freqs_y = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim))
182
+
183
+ t_x, t_y = init_t_xy(end_x, end_y)
184
+ freqs_x = torch.outer(t_x, freqs_x)
185
+ freqs_y = torch.outer(t_y, freqs_y)
186
+ freqs_cis_x = torch.polar(torch.ones_like(freqs_x), freqs_x)
187
+ freqs_cis_y = torch.polar(torch.ones_like(freqs_y), freqs_y)
188
+ return torch.cat([freqs_cis_x, freqs_cis_y], dim=-1)
189
+
190
+
191
+ def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
192
+ ndim = x.ndim
193
+ assert 0 <= 1 < ndim
194
+ assert freqs_cis.shape == (x.shape[-2], x.shape[-1])
195
+ shape = [d if i >= ndim - 2 else 1 for i, d in enumerate(x.shape)]
196
+ return freqs_cis.view(*shape)
197
+
198
+
199
+ def apply_rotary_enc(
200
+ xq: torch.Tensor,
201
+ xk: torch.Tensor,
202
+ freqs_cis: torch.Tensor,
203
+ repeat_freqs_k: bool = False,
204
+ ):
205
+ xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
206
+ xk_ = (
207
+ torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
208
+ if xk.shape[-2] != 0
209
+ else None
210
+ )
211
+ freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
212
+ xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
213
+ if xk_ is None:
214
+ # no keys to rotate, due to dropout
215
+ return xq_out.type_as(xq).to(xq.device), xk
216
+ # repeat freqs along seq_len dim to match k seq_len
217
+ if repeat_freqs_k:
218
+ r = xk_.shape[-2] // xq_.shape[-2]
219
+ freqs_cis = freqs_cis.repeat(*([1] * (freqs_cis.ndim - 2)), r, 1)
220
+ xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
221
+ return xq_out.type_as(xq).to(xq.device), xk_out.type_as(xk).to(xk.device)
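The rotary helpers above can be exercised on their own; a minimal sketch, assuming a head dimension divisible by 4 and a square token grid (all shapes below are illustrative, not prescribed by the module):

import torch
from third_parts.sam2.modeling.position_encoding import apply_rotary_enc, compute_axial_cis

B, heads, H, W, head_dim = 1, 2, 8, 8, 32         # illustrative shapes only
q = torch.randn(B, heads, H * W, head_dim)
k = torch.randn(B, heads, H * W, head_dim)

# One complex rotation per (x, y) grid position, head_dim // 2 complex pairs per token.
freqs_cis = compute_axial_cis(dim=head_dim, end_x=W, end_y=H)  # [H*W, head_dim // 2]
q_rot, k_rot = apply_rotary_enc(q, k, freqs_cis=freqs_cis)
assert q_rot.shape == q.shape and k_rot.shape == k.shape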
third_parts/sam2/modeling/sam/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
third_parts/sam2/modeling/sam/mask_decoder.py ADDED
@@ -0,0 +1,299 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from typing import List, Optional, Tuple, Type
8
+
9
+ import torch
10
+ from torch import nn
11
+
12
+ from third_parts.sam2.modeling.sam2_utils import LayerNorm2d, MLP
13
+
14
+
15
+ class MaskDecoder(nn.Module):
16
+ def __init__(
17
+ self,
18
+ *,
19
+ transformer_dim: int,
20
+ transformer: nn.Module,
21
+ num_multimask_outputs: int = 3,
22
+ activation: Type[nn.Module] = nn.GELU,
23
+ iou_head_depth: int = 3,
24
+ iou_head_hidden_dim: int = 256,
25
+ use_high_res_features: bool = False,
26
+ iou_prediction_use_sigmoid=False,
27
+ dynamic_multimask_via_stability=False,
28
+ dynamic_multimask_stability_delta=0.05,
29
+ dynamic_multimask_stability_thresh=0.98,
30
+ pred_obj_scores: bool = False,
31
+ pred_obj_scores_mlp: bool = False,
32
+ use_multimask_token_for_obj_ptr: bool = False,
33
+ ) -> None:
34
+ """
35
+ Predicts masks given an image and prompt embeddings, using a
36
+ transformer architecture.
37
+
38
+ Arguments:
39
+ transformer_dim (int): the channel dimension of the transformer
40
+ transformer (nn.Module): the transformer used to predict masks
41
+ num_multimask_outputs (int): the number of masks to predict
42
+ when disambiguating masks
43
+ activation (nn.Module): the type of activation to use when
44
+ upscaling masks
45
+ iou_head_depth (int): the depth of the MLP used to predict
46
+ mask quality
47
+ iou_head_hidden_dim (int): the hidden dimension of the MLP
48
+ used to predict mask quality
49
+ """
50
+ super().__init__()
51
+ self.transformer_dim = transformer_dim
52
+ self.transformer = transformer
53
+
54
+ self.num_multimask_outputs = num_multimask_outputs
55
+
56
+ self.iou_token = nn.Embedding(1, transformer_dim)
57
+ self.num_mask_tokens = num_multimask_outputs + 1
58
+ self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim)
59
+
60
+ self.pred_obj_scores = pred_obj_scores
61
+ if self.pred_obj_scores:
62
+ self.obj_score_token = nn.Embedding(1, transformer_dim)
63
+ self.use_multimask_token_for_obj_ptr = use_multimask_token_for_obj_ptr
64
+
65
+ self.output_upscaling = nn.Sequential(
66
+ nn.ConvTranspose2d(
67
+ transformer_dim, transformer_dim // 4, kernel_size=2, stride=2
68
+ ),
69
+ LayerNorm2d(transformer_dim // 4),
70
+ activation(),
71
+ nn.ConvTranspose2d(
72
+ transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2
73
+ ),
74
+ activation(),
75
+ )
76
+ self.use_high_res_features = use_high_res_features
77
+ if use_high_res_features:
78
+ self.conv_s0 = nn.Conv2d(
79
+ transformer_dim, transformer_dim // 8, kernel_size=1, stride=1
80
+ )
81
+ self.conv_s1 = nn.Conv2d(
82
+ transformer_dim, transformer_dim // 4, kernel_size=1, stride=1
83
+ )
84
+
85
+ self.output_hypernetworks_mlps = nn.ModuleList(
86
+ [
87
+ MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3)
88
+ for i in range(self.num_mask_tokens)
89
+ ]
90
+ )
91
+
92
+ self.iou_prediction_head = MLP(
93
+ transformer_dim,
94
+ iou_head_hidden_dim,
95
+ self.num_mask_tokens,
96
+ iou_head_depth,
97
+ sigmoid_output=iou_prediction_use_sigmoid,
98
+ )
99
+ if self.pred_obj_scores:
100
+ self.pred_obj_score_head = nn.Linear(transformer_dim, 1)
101
+ if pred_obj_scores_mlp:
102
+ self.pred_obj_score_head = MLP(transformer_dim, transformer_dim, 1, 3)
103
+
104
+ # When outputting a single mask, optionally we can dynamically fall back to the best
105
+ # multimask output token if the single mask output token gives low stability scores.
106
+ self.dynamic_multimask_via_stability = dynamic_multimask_via_stability
107
+ self.dynamic_multimask_stability_delta = dynamic_multimask_stability_delta
108
+ self.dynamic_multimask_stability_thresh = dynamic_multimask_stability_thresh
109
+
110
+ def forward(
111
+ self,
112
+ image_embeddings: torch.Tensor,
113
+ image_pe: torch.Tensor,
114
+ sparse_prompt_embeddings: torch.Tensor,
115
+ dense_prompt_embeddings: torch.Tensor,
116
+ multimask_output: bool,
117
+ repeat_image: bool,
118
+ high_res_features: Optional[List[torch.Tensor]] = None,
119
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
120
+ """
121
+ Predict masks given image and prompt embeddings.
122
+
123
+ Arguments:
124
+ image_embeddings (torch.Tensor): the embeddings from the image encoder
125
+ image_pe (torch.Tensor): positional encoding with the shape of image_embeddings
126
+ sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes
127
+ dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs
128
+ multimask_output (bool): Whether to return multiple masks or a single
129
+ mask.
130
+
131
+ Returns:
132
+ torch.Tensor: batched predicted masks
133
+ torch.Tensor: batched predictions of mask quality
134
+ torch.Tensor: batched SAM token for mask output
135
+ """
136
+ masks, iou_pred, mask_tokens_out, object_score_logits = self.predict_masks(
137
+ image_embeddings=image_embeddings,
138
+ image_pe=image_pe,
139
+ sparse_prompt_embeddings=sparse_prompt_embeddings,
140
+ dense_prompt_embeddings=dense_prompt_embeddings,
141
+ repeat_image=repeat_image,
142
+ high_res_features=high_res_features,
143
+ )
144
+
145
+ # Select the correct mask or masks for output
146
+ if multimask_output:
147
+ masks = masks[:, 1:, :, :]
148
+ iou_pred = iou_pred[:, 1:]
149
+ elif self.dynamic_multimask_via_stability and not self.training:
150
+ masks, iou_pred = self._dynamic_multimask_via_stability(masks, iou_pred)
151
+ else:
152
+ masks = masks[:, 0:1, :, :]
153
+ iou_pred = iou_pred[:, 0:1]
154
+
155
+ if multimask_output and self.use_multimask_token_for_obj_ptr:
156
+ sam_tokens_out = mask_tokens_out[:, 1:] # [b, 3, c] shape
157
+ else:
158
+ # Take the mask output token. Here we *always* use the token for single mask output.
159
+ # At test time, even if we track after 1-click (and using multimask_output=True),
160
+ # we still take the single mask token here. The rationale is that we always track
161
+ # after multiple clicks during training, so the past tokens seen during training
162
+ # are always the single mask token (and we'll let it be the object-memory token).
163
+ sam_tokens_out = mask_tokens_out[:, 0:1] # [b, 1, c] shape
164
+
165
+ # Prepare output
166
+ return masks, iou_pred, sam_tokens_out, object_score_logits
167
+
168
+ def predict_masks(
169
+ self,
170
+ image_embeddings: torch.Tensor,
171
+ image_pe: torch.Tensor,
172
+ sparse_prompt_embeddings: torch.Tensor,
173
+ dense_prompt_embeddings: torch.Tensor,
174
+ repeat_image: bool,
175
+ high_res_features: Optional[List[torch.Tensor]] = None,
176
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
177
+ """Predicts masks. See 'forward' for more details."""
178
+ # Concatenate output tokens
179
+ s = 0
180
+ if self.pred_obj_scores:
181
+ output_tokens = torch.cat(
182
+ [
183
+ self.obj_score_token.weight,
184
+ self.iou_token.weight,
185
+ self.mask_tokens.weight,
186
+ ],
187
+ dim=0,
188
+ )
189
+ s = 1
190
+ else:
191
+ output_tokens = torch.cat(
192
+ [self.iou_token.weight, self.mask_tokens.weight], dim=0
193
+ )
194
+ output_tokens = output_tokens.unsqueeze(0).expand(
195
+ sparse_prompt_embeddings.size(0), -1, -1
196
+ )
197
+ tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)
198
+
199
+ # Expand per-image data in batch direction to be per-mask
200
+ if repeat_image:
201
+ src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
202
+ else:
203
+ assert image_embeddings.shape[0] == tokens.shape[0]
204
+ src = image_embeddings
205
+ src = src + dense_prompt_embeddings
206
+ assert (
207
+ image_pe.size(0) == 1
208
+ ), "image_pe should have size 1 in batch dim (from `get_dense_pe()`)"
209
+ pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
210
+ b, c, h, w = src.shape
211
+
212
+ # Run the transformer
213
+ # print('src: ', src.dtype, 'pos_src:', pos_src.dtype, 'tokens:', tokens.dtype)
214
+ _dtype = pos_src.dtype
215
+ src = src.to(_dtype)
216
+ tokens = tokens.to(_dtype)
217
+ hs, src = self.transformer(src, pos_src, tokens)
218
+ iou_token_out = hs[:, s, :]
219
+ mask_tokens_out = hs[:, s + 1 : (s + 1 + self.num_mask_tokens), :]
220
+
221
+ # Upscale mask embeddings and predict masks using the mask tokens
222
+ src = src.transpose(1, 2).view(b, c, h, w)
223
+ if not self.use_high_res_features:
224
+ upscaled_embedding = self.output_upscaling(src)
225
+ else:
226
+ dc1, ln1, act1, dc2, act2 = self.output_upscaling
227
+ feat_s0, feat_s1 = high_res_features
228
+ upscaled_embedding = act1(ln1(dc1(src) + feat_s1))
229
+ upscaled_embedding = act2(dc2(upscaled_embedding) + feat_s0)
230
+
231
+ hyper_in_list: List[torch.Tensor] = []
232
+ for i in range(self.num_mask_tokens):
233
+ hyper_in_list.append(
234
+ self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :])
235
+ )
236
+ hyper_in = torch.stack(hyper_in_list, dim=1)
237
+ b, c, h, w = upscaled_embedding.shape
238
+ masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)
239
+
240
+ # Generate mask quality predictions
241
+ iou_pred = self.iou_prediction_head(iou_token_out)
242
+ if self.pred_obj_scores:
243
+ assert s == 1
244
+ object_score_logits = self.pred_obj_score_head(hs[:, 0, :])
245
+ else:
246
+ # Obj score logits - default to 10.0, i.e. assuming the object is present, sigmoid(10) ≈ 1
247
+ object_score_logits = 10.0 * iou_pred.new_ones(iou_pred.shape[0], 1)
248
+
249
+ return masks, iou_pred, mask_tokens_out, object_score_logits
250
+
251
+ def _get_stability_scores(self, mask_logits):
252
+ """
253
+ Compute stability scores of the mask logits based on the IoU between upper and
254
+ lower thresholds, similar to https://github.com/fairinternal/onevision/pull/568.
255
+ """
256
+ mask_logits = mask_logits.flatten(-2)
257
+ stability_delta = self.dynamic_multimask_stability_delta
258
+ area_i = torch.sum(mask_logits > stability_delta, dim=-1).float()
259
+ area_u = torch.sum(mask_logits > -stability_delta, dim=-1).float()
260
+ stability_scores = torch.where(area_u > 0, area_i / area_u, 1.0)
261
+ return stability_scores
262
+
263
+ def _dynamic_multimask_via_stability(self, all_mask_logits, all_iou_scores):
264
+ """
265
+ When outputting a single mask, if the stability score from the current single-mask
266
+ output (based on output token 0) falls below a threshold, we instead select from
267
+ multi-mask outputs (based on output token 1~3) the mask with the highest predicted
268
+ IoU score. This is intended to ensure a valid mask for both clicking and tracking.
269
+ """
270
+ # The best mask from multimask output tokens (1~3)
271
+ multimask_logits = all_mask_logits[:, 1:, :, :]
272
+ multimask_iou_scores = all_iou_scores[:, 1:]
273
+ best_scores_inds = torch.argmax(multimask_iou_scores, dim=-1)
274
+ batch_inds = torch.arange(
275
+ multimask_iou_scores.size(0), device=all_iou_scores.device
276
+ )
277
+ best_multimask_logits = multimask_logits[batch_inds, best_scores_inds]
278
+ best_multimask_logits = best_multimask_logits.unsqueeze(1)
279
+ best_multimask_iou_scores = multimask_iou_scores[batch_inds, best_scores_inds]
280
+ best_multimask_iou_scores = best_multimask_iou_scores.unsqueeze(1)
281
+
282
+ # The mask from singlemask output token 0 and its stability score
283
+ singlemask_logits = all_mask_logits[:, 0:1, :, :]
284
+ singlemask_iou_scores = all_iou_scores[:, 0:1]
285
+ stability_scores = self._get_stability_scores(singlemask_logits)
286
+ is_stable = stability_scores >= self.dynamic_multimask_stability_thresh
287
+
288
+ # Dynamically fall back to best multimask output upon low stability scores.
289
+ mask_logits_out = torch.where(
290
+ is_stable[..., None, None].expand_as(singlemask_logits),
291
+ singlemask_logits,
292
+ best_multimask_logits,
293
+ )
294
+ iou_scores_out = torch.where(
295
+ is_stable.expand_as(singlemask_iou_scores),
296
+ singlemask_iou_scores,
297
+ best_multimask_iou_scores,
298
+ )
299
+ return mask_logits_out, iou_scores_out
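The stability score used in the fallback above is just the IoU between the mask thresholded at +delta and at -delta; if it drops below `dynamic_multimask_stability_thresh`, the best multimask output is taken instead. A standalone sketch of that arithmetic (the shapes and the 0.05 delta mirror the defaults but are only illustrative here):

import torch

mask_logits = torch.randn(2, 1, 64, 64)       # [B, 1, H, W] illustrative mask logits
stability_delta = 0.05

flat = mask_logits.flatten(-2)                                # flatten the spatial dims
area_i = torch.sum(flat > stability_delta, dim=-1).float()    # pixels above +delta ("intersection")
area_u = torch.sum(flat > -stability_delta, dim=-1).float()   # pixels above -delta ("union")
stability = torch.where(area_u > 0, area_i / area_u, torch.ones_like(area_u))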
third_parts/sam2/modeling/sam/prompt_encoder.py ADDED
@@ -0,0 +1,182 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from typing import Optional, Tuple, Type
8
+
9
+ import torch
10
+ from torch import nn
11
+
12
+ from third_parts.sam2.modeling.position_encoding import PositionEmbeddingRandom
13
+
14
+ from third_parts.sam2.modeling.sam2_utils import LayerNorm2d
15
+
16
+
17
+ class PromptEncoder(nn.Module):
18
+ def __init__(
19
+ self,
20
+ embed_dim: int,
21
+ image_embedding_size: Tuple[int, int],
22
+ input_image_size: Tuple[int, int],
23
+ mask_in_chans: int,
24
+ activation: Type[nn.Module] = nn.GELU,
25
+ ) -> None:
26
+ """
27
+ Encodes prompts for input to SAM's mask decoder.
28
+
29
+ Arguments:
30
+ embed_dim (int): The prompts' embedding dimension
31
+ image_embedding_size (tuple(int, int)): The spatial size of the
32
+ image embedding, as (H, W).
33
+ input_image_size (int): The padded size of the image as input
34
+ to the image encoder, as (H, W).
35
+ mask_in_chans (int): The number of hidden channels used for
36
+ encoding input masks.
37
+ activation (nn.Module): The activation to use when encoding
38
+ input masks.
39
+ """
40
+ super().__init__()
41
+ self.embed_dim = embed_dim
42
+ self.input_image_size = input_image_size
43
+ self.image_embedding_size = image_embedding_size
44
+ self.pe_layer = PositionEmbeddingRandom(embed_dim // 2)
45
+
46
+ self.num_point_embeddings: int = 4 # pos/neg point + 2 box corners
47
+ point_embeddings = [
48
+ nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings)
49
+ ]
50
+ self.point_embeddings = nn.ModuleList(point_embeddings)
51
+ self.not_a_point_embed = nn.Embedding(1, embed_dim)
52
+
53
+ self.mask_input_size = (
54
+ 4 * image_embedding_size[0],
55
+ 4 * image_embedding_size[1],
56
+ )
57
+ self.mask_downscaling = nn.Sequential(
58
+ nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2),
59
+ LayerNorm2d(mask_in_chans // 4),
60
+ activation(),
61
+ nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2),
62
+ LayerNorm2d(mask_in_chans),
63
+ activation(),
64
+ nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1),
65
+ )
66
+ self.no_mask_embed = nn.Embedding(1, embed_dim)
67
+
68
+ def get_dense_pe(self) -> torch.Tensor:
69
+ """
70
+ Returns the positional encoding used to encode point prompts,
71
+ applied to a dense set of points the shape of the image encoding.
72
+
73
+ Returns:
74
+ torch.Tensor: Positional encoding with shape
75
+ 1x(embed_dim)x(embedding_h)x(embedding_w)
76
+ """
77
+ return self.pe_layer(self.image_embedding_size).unsqueeze(0)
78
+
79
+ def _embed_points(
80
+ self,
81
+ points: torch.Tensor,
82
+ labels: torch.Tensor,
83
+ pad: bool,
84
+ ) -> torch.Tensor:
85
+ """Embeds point prompts."""
86
+ points = points + 0.5 # Shift to center of pixel
87
+ if pad:
88
+ padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device)
89
+ padding_label = -torch.ones((labels.shape[0], 1), device=labels.device)
90
+ points = torch.cat([points, padding_point], dim=1)
91
+ labels = torch.cat([labels, padding_label], dim=1)
92
+ point_embedding = self.pe_layer.forward_with_coords(
93
+ points, self.input_image_size
94
+ )
95
+ point_embedding[labels == -1] = 0.0
96
+ point_embedding[labels == -1] += self.not_a_point_embed.weight
97
+ point_embedding[labels == 0] += self.point_embeddings[0].weight
98
+ point_embedding[labels == 1] += self.point_embeddings[1].weight
99
+ point_embedding[labels == 2] += self.point_embeddings[2].weight
100
+ point_embedding[labels == 3] += self.point_embeddings[3].weight
101
+ return point_embedding
102
+
103
+ def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
104
+ """Embeds box prompts."""
105
+ boxes = boxes + 0.5 # Shift to center of pixel
106
+ coords = boxes.reshape(-1, 2, 2)
107
+ corner_embedding = self.pe_layer.forward_with_coords(
108
+ coords, self.input_image_size
109
+ )
110
+ corner_embedding[:, 0, :] += self.point_embeddings[2].weight
111
+ corner_embedding[:, 1, :] += self.point_embeddings[3].weight
112
+ return corner_embedding
113
+
114
+ def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor:
115
+ """Embeds mask inputs."""
116
+ mask_embedding = self.mask_downscaling(masks)
117
+ return mask_embedding
118
+
119
+ def _get_batch_size(
120
+ self,
121
+ points: Optional[Tuple[torch.Tensor, torch.Tensor]],
122
+ boxes: Optional[torch.Tensor],
123
+ masks: Optional[torch.Tensor],
124
+ ) -> int:
125
+ """
126
+ Gets the batch size of the output given the batch size of the input prompts.
127
+ """
128
+ if points is not None:
129
+ return points[0].shape[0]
130
+ elif boxes is not None:
131
+ return boxes.shape[0]
132
+ elif masks is not None:
133
+ return masks.shape[0]
134
+ else:
135
+ return 1
136
+
137
+ def _get_device(self) -> torch.device:
138
+ return self.point_embeddings[0].weight.device
139
+
140
+ def forward(
141
+ self,
142
+ points: Optional[Tuple[torch.Tensor, torch.Tensor]],
143
+ boxes: Optional[torch.Tensor],
144
+ masks: Optional[torch.Tensor],
145
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
146
+ """
147
+ Embeds different types of prompts, returning both sparse and dense
148
+ embeddings.
149
+
150
+ Arguments:
151
+ points (tuple(torch.Tensor, torch.Tensor) or none): point coordinates
152
+ and labels to embed.
153
+ boxes (torch.Tensor or none): boxes to embed
154
+ masks (torch.Tensor or none): masks to embed
155
+
156
+ Returns:
157
+ torch.Tensor: sparse embeddings for the points and boxes, with shape
158
+ BxNx(embed_dim), where N is determined by the number of input points
159
+ and boxes.
160
+ torch.Tensor: dense embeddings for the masks, in the shape
161
+ Bx(embed_dim)x(embed_H)x(embed_W)
162
+ """
163
+ bs = self._get_batch_size(points, boxes, masks)
164
+ sparse_embeddings = torch.empty(
165
+ (bs, 0, self.embed_dim), device=self._get_device()
166
+ )
167
+ if points is not None:
168
+ coords, labels = points
169
+ point_embeddings = self._embed_points(coords, labels, pad=(boxes is None))
170
+ sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1)
171
+ if boxes is not None:
172
+ box_embeddings = self._embed_boxes(boxes)
173
+ sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1)
174
+
175
+ if masks is not None:
176
+ dense_embeddings = self._embed_masks(masks)
177
+ else:
178
+ dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(
179
+ bs, -1, self.image_embedding_size[0], self.image_embedding_size[1]
180
+ )
181
+
182
+ return sparse_embeddings, dense_embeddings
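As a usage sketch of the prompt encoder above: the sizes assume a 512-pixel input and a stride-16 backbone (32x32 image embeddings), which are assumptions for illustration rather than requirements of the module.

import torch
from third_parts.sam2.modeling.sam.prompt_encoder import PromptEncoder

enc = PromptEncoder(
    embed_dim=256,
    image_embedding_size=(32, 32),
    input_image_size=(512, 512),
    mask_in_chans=16,
)
coords = torch.tensor([[[100.0, 200.0]]])  # [B=1, N=1, 2] pixel (x, y) coordinates
labels = torch.tensor([[1.0]])             # 1 = positive click (float, to match the internal pad label)
sparse, dense = enc(points=(coords, labels), boxes=None, masks=None)
# sparse: [1, 2, 256] (the click plus an internal padding point); dense: [1, 256, 32, 32]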
third_parts/sam2/modeling/sam/transformer.py ADDED
@@ -0,0 +1,328 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import math
8
+ import warnings
9
+ from functools import partial
10
+ from typing import Tuple, Type
11
+
12
+ import torch
13
+ import torch.nn.functional as F
14
+ from torch import nn, Tensor
15
+
16
+ from third_parts.sam2.modeling.position_encoding import apply_rotary_enc, compute_axial_cis
17
+
18
+ from third_parts.sam2.modeling.sam2_utils import MLP
19
+ from third_parts.sam2.utils.misc import get_sdpa_settings
20
+
21
+ warnings.simplefilter(action="ignore", category=FutureWarning)
22
+ # OLD_GPU, USE_FLASH_ATTN, MATH_KERNEL_ON = get_sdpa_settings()
23
+ OLD_GPU, USE_FLASH_ATTN, MATH_KERNEL_ON = True, True, True
24
+
25
+
26
+ class TwoWayTransformer(nn.Module):
27
+ def __init__(
28
+ self,
29
+ depth: int,
30
+ embedding_dim: int,
31
+ num_heads: int,
32
+ mlp_dim: int,
33
+ activation: Type[nn.Module] = nn.ReLU,
34
+ attention_downsample_rate: int = 2,
35
+ ) -> None:
36
+ """
37
+ A transformer decoder that attends to an input image using
38
+ queries whose positional embedding is supplied.
39
+
40
+ Args:
41
+ depth (int): number of layers in the transformer
42
+ embedding_dim (int): the channel dimension for the input embeddings
43
+ num_heads (int): the number of heads for multihead attention. Must
44
+ divide embedding_dim
45
+ mlp_dim (int): the channel dimension internal to the MLP block
46
+ activation (nn.Module): the activation to use in the MLP block
47
+ """
48
+ super().__init__()
49
+ self.depth = depth
50
+ self.embedding_dim = embedding_dim
51
+ self.num_heads = num_heads
52
+ self.mlp_dim = mlp_dim
53
+ self.layers = nn.ModuleList()
54
+
55
+ for i in range(depth):
56
+ self.layers.append(
57
+ TwoWayAttentionBlock(
58
+ embedding_dim=embedding_dim,
59
+ num_heads=num_heads,
60
+ mlp_dim=mlp_dim,
61
+ activation=activation,
62
+ attention_downsample_rate=attention_downsample_rate,
63
+ skip_first_layer_pe=(i == 0),
64
+ )
65
+ )
66
+
67
+ self.final_attn_token_to_image = Attention(
68
+ embedding_dim, num_heads, downsample_rate=attention_downsample_rate
69
+ )
70
+ self.norm_final_attn = nn.LayerNorm(embedding_dim)
71
+
72
+ def forward(
73
+ self,
74
+ image_embedding: Tensor,
75
+ image_pe: Tensor,
76
+ point_embedding: Tensor,
77
+ ) -> Tuple[Tensor, Tensor]:
78
+ """
79
+ Args:
80
+ image_embedding (torch.Tensor): image to attend to. Should be shape
81
+ B x embedding_dim x h x w for any h and w.
82
+ image_pe (torch.Tensor): the positional encoding to add to the image. Must
83
+ have the same shape as image_embedding.
84
+ point_embedding (torch.Tensor): the embedding to add to the query points.
85
+ Must have shape B x N_points x embedding_dim for any N_points.
86
+
87
+ Returns:
88
+ torch.Tensor: the processed point_embedding
89
+ torch.Tensor: the processed image_embedding
90
+ """
91
+ # BxCxHxW -> BxHWxC == B x N_image_tokens x C
92
+ bs, c, h, w = image_embedding.shape
93
+ image_embedding = image_embedding.flatten(2).permute(0, 2, 1)
94
+ image_pe = image_pe.flatten(2).permute(0, 2, 1)
95
+
96
+ # Prepare queries
97
+ queries = point_embedding
98
+ keys = image_embedding
99
+
100
+ # Apply transformer blocks and final layernorm
101
+ for layer in self.layers:
102
+ queries, keys = layer(
103
+ queries=queries,
104
+ keys=keys,
105
+ query_pe=point_embedding,
106
+ key_pe=image_pe,
107
+ )
108
+
109
+ # Apply the final attention layer from the points to the image
110
+ q = queries + point_embedding
111
+ k = keys + image_pe
112
+ attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys)
113
+ queries = queries + attn_out
114
+ queries = self.norm_final_attn(queries)
115
+
116
+ return queries, keys
117
+
118
+
119
+ class TwoWayAttentionBlock(nn.Module):
120
+ def __init__(
121
+ self,
122
+ embedding_dim: int,
123
+ num_heads: int,
124
+ mlp_dim: int = 2048,
125
+ activation: Type[nn.Module] = nn.ReLU,
126
+ attention_downsample_rate: int = 2,
127
+ skip_first_layer_pe: bool = False,
128
+ ) -> None:
129
+ """
130
+ A transformer block with four layers: (1) self-attention of sparse
131
+ inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp
132
+ block on sparse inputs, and (4) cross attention of dense inputs to sparse
133
+ inputs.
134
+
135
+ Arguments:
136
+ embedding_dim (int): the channel dimension of the embeddings
137
+ num_heads (int): the number of heads in the attention layers
138
+ mlp_dim (int): the hidden dimension of the mlp block
139
+ activation (nn.Module): the activation of the mlp block
140
+ skip_first_layer_pe (bool): skip the PE on the first layer
141
+ """
142
+ super().__init__()
143
+ self.self_attn = Attention(embedding_dim, num_heads)
144
+ self.norm1 = nn.LayerNorm(embedding_dim)
145
+
146
+ self.cross_attn_token_to_image = Attention(
147
+ embedding_dim, num_heads, downsample_rate=attention_downsample_rate
148
+ )
149
+ self.norm2 = nn.LayerNorm(embedding_dim)
150
+
151
+ self.mlp = MLP(
152
+ embedding_dim, mlp_dim, embedding_dim, num_layers=2, activation=activation
153
+ )
154
+ self.norm3 = nn.LayerNorm(embedding_dim)
155
+
156
+ self.norm4 = nn.LayerNorm(embedding_dim)
157
+ self.cross_attn_image_to_token = Attention(
158
+ embedding_dim, num_heads, downsample_rate=attention_downsample_rate
159
+ )
160
+
161
+ self.skip_first_layer_pe = skip_first_layer_pe
162
+
163
+ def forward(
164
+ self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor
165
+ ) -> Tuple[Tensor, Tensor]:
166
+ # Self attention block
167
+ if self.skip_first_layer_pe:
168
+ queries = self.self_attn(q=queries, k=queries, v=queries)
169
+ else:
170
+ q = queries + query_pe
171
+ attn_out = self.self_attn(q=q, k=q, v=queries)
172
+ queries = queries + attn_out
173
+ queries = self.norm1(queries)
174
+
175
+ # Cross attention block, tokens attending to image embedding
176
+ q = queries + query_pe
177
+ k = keys + key_pe
178
+ attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys)
179
+ queries = queries + attn_out
180
+ queries = self.norm2(queries)
181
+
182
+ # MLP block
183
+ mlp_out = self.mlp(queries)
184
+ queries = queries + mlp_out
185
+ queries = self.norm3(queries)
186
+
187
+ # Cross attention block, image embedding attending to tokens
188
+ q = queries + query_pe
189
+ k = keys + key_pe
190
+ attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries)
191
+ keys = keys + attn_out
192
+ keys = self.norm4(keys)
193
+
194
+ return queries, keys
195
+
196
+
197
+ class Attention(nn.Module):
198
+ """
199
+ An attention layer that allows for downscaling the size of the embedding
200
+ after projection to queries, keys, and values.
201
+ """
202
+
203
+ def __init__(
204
+ self,
205
+ embedding_dim: int,
206
+ num_heads: int,
207
+ downsample_rate: int = 1,
208
+ dropout: float = 0.0,
209
+ kv_in_dim: int = None,
210
+ ) -> None:
211
+ super().__init__()
212
+ self.embedding_dim = embedding_dim
213
+ self.kv_in_dim = kv_in_dim if kv_in_dim is not None else embedding_dim
214
+ self.internal_dim = embedding_dim // downsample_rate
215
+ self.num_heads = num_heads
216
+ assert (
217
+ self.internal_dim % num_heads == 0
218
+ ), "num_heads must divide internal_dim (embedding_dim // downsample_rate)."
219
+
220
+ self.q_proj = nn.Linear(embedding_dim, self.internal_dim)
221
+ self.k_proj = nn.Linear(self.kv_in_dim, self.internal_dim)
222
+ self.v_proj = nn.Linear(self.kv_in_dim, self.internal_dim)
223
+ self.out_proj = nn.Linear(self.internal_dim, embedding_dim)
224
+
225
+ self.dropout_p = dropout
226
+
227
+ def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor:
228
+ b, n, c = x.shape
229
+ x = x.reshape(b, n, num_heads, c // num_heads)
230
+ return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head
231
+
232
+ def _recombine_heads(self, x: Tensor) -> Tensor:
233
+ b, n_heads, n_tokens, c_per_head = x.shape
234
+ x = x.transpose(1, 2)
235
+ return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C
236
+
237
+ def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
238
+ # Input projections
239
+ q = self.q_proj(q)
240
+ k = self.k_proj(k)
241
+ v = self.v_proj(v)
242
+
243
+ # Separate into heads
244
+ q = self._separate_heads(q, self.num_heads)
245
+ k = self._separate_heads(k, self.num_heads)
246
+ v = self._separate_heads(v, self.num_heads)
247
+
248
+ dropout_p = self.dropout_p if self.training else 0.0
249
+ # Attention
250
+ with torch.backends.cuda.sdp_kernel(
251
+ enable_flash=USE_FLASH_ATTN,
252
+ # if Flash attention kernel is off, then math kernel needs to be enabled
253
+ enable_math=(OLD_GPU and dropout_p > 0.0) or MATH_KERNEL_ON,
254
+ enable_mem_efficient=OLD_GPU,
255
+ ):
256
+ out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
257
+
258
+ out = self._recombine_heads(out)
259
+ out = self.out_proj(out)
260
+
261
+ return out
262
+
263
+
264
+ class RoPEAttention(Attention):
265
+ """Attention with rotary position encoding."""
266
+
267
+ def __init__(
268
+ self,
269
+ *args,
270
+ rope_theta=10000.0,
271
+ # whether to repeat q rope to match k length
272
+ # this is needed for cross-attention to memories
273
+ rope_k_repeat=False,
274
+ feat_sizes=(32, 32), # [w, h] for stride 16 feats at 512 resolution
275
+ **kwargs,
276
+ ):
277
+ super().__init__(*args, **kwargs)
278
+
279
+ self.compute_cis = partial(
280
+ compute_axial_cis, dim=self.internal_dim // self.num_heads, theta=rope_theta
281
+ )
282
+ freqs_cis = self.compute_cis(end_x=feat_sizes[0], end_y=feat_sizes[1])
283
+ self.freqs_cis = freqs_cis
284
+ self.rope_k_repeat = rope_k_repeat
285
+
286
+ def forward(
287
+ self, q: Tensor, k: Tensor, v: Tensor, num_k_exclude_rope: int = 0
288
+ ) -> Tensor:
289
+ # Input projections
290
+ q = self.q_proj(q)
291
+ k = self.k_proj(k)
292
+ v = self.v_proj(v)
293
+
294
+ # Separate into heads
295
+ q = self._separate_heads(q, self.num_heads)
296
+ k = self._separate_heads(k, self.num_heads)
297
+ v = self._separate_heads(v, self.num_heads)
298
+
299
+ # Apply rotary position encoding
300
+ w = h = math.sqrt(q.shape[-2])
301
+ self.freqs_cis = self.freqs_cis.to(q.device)
302
+ if self.freqs_cis.shape[0] != q.shape[-2]:
303
+ self.freqs_cis = self.compute_cis(end_x=w, end_y=h).to(q.device)
304
+ if q.shape[-2] != k.shape[-2]:
305
+ assert self.rope_k_repeat
306
+
307
+ num_k_rope = k.size(-2) - num_k_exclude_rope
308
+ q, k[:, :, :num_k_rope] = apply_rotary_enc(
309
+ q,
310
+ k[:, :, :num_k_rope],
311
+ freqs_cis=self.freqs_cis,
312
+ repeat_freqs_k=self.rope_k_repeat,
313
+ )
314
+
315
+ dropout_p = self.dropout_p if self.training else 0.0
316
+ # Attention
317
+ with torch.backends.cuda.sdp_kernel(
318
+ enable_flash=USE_FLASH_ATTN,
319
+ # if Flash attention kernel is off, then math kernel needs to be enabled
320
+ enable_math=(OLD_GPU and dropout_p > 0.0) or MATH_KERNEL_ON,
321
+ enable_mem_efficient=OLD_GPU,
322
+ ):
323
+ out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
324
+
325
+ out = self._recombine_heads(out)
326
+ out = self.out_proj(out)
327
+
328
+ return out
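A minimal sketch of the downsampled attention defined above, with assumed token counts (a few prompt tokens attending to a 32x32 grid of flattened image tokens):

import torch
from third_parts.sam2.modeling.sam.transformer import Attention

attn = Attention(embedding_dim=256, num_heads=8, downsample_rate=2)
q = torch.randn(1, 5, 256)          # e.g. prompt/output tokens
kv = torch.randn(1, 32 * 32, 256)   # e.g. flattened image tokens
out = attn(q=q, k=kv, v=kv)         # [1, 5, 256]; projections run at internal_dim = 256 // 2 = 128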
third_parts/sam2/modeling/sam2_base.py ADDED
@@ -0,0 +1,830 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import torch
8
+ import torch.distributed
9
+ import torch.nn.functional as F
10
+
11
+ from torch.nn.init import trunc_normal_
12
+
13
+ from third_parts.sam2.modeling.sam.mask_decoder import MaskDecoder
14
+ from third_parts.sam2.modeling.sam.prompt_encoder import PromptEncoder
15
+ from third_parts.sam2.modeling.sam.transformer import TwoWayTransformer
16
+ from third_parts.sam2.modeling.sam2_utils import get_1d_sine_pe, MLP, select_closest_cond_frames
17
+
18
+ # a large negative value as a placeholder score for missing objects
19
+ NO_OBJ_SCORE = -1024.0
20
+
21
+
22
+ class SAM2Base(torch.nn.Module):
23
+ def __init__(
24
+ self,
25
+ image_encoder,
26
+ memory_attention,
27
+ memory_encoder,
28
+ num_maskmem=7, # default 1 input frame + 6 previous frames
29
+ image_size=512,
30
+ backbone_stride=16, # stride of the image backbone output
31
+ sigmoid_scale_for_mem_enc=1.0, # scale factor for mask sigmoid prob
32
+ sigmoid_bias_for_mem_enc=0.0, # bias factor for mask sigmoid prob
33
+ # During evaluation, whether to binarize the sigmoid mask logits on interacted frames with clicks
34
+ binarize_mask_from_pts_for_mem_enc=False,
35
+ use_mask_input_as_output_without_sam=False, # on frames with mask input, whether to directly output the input mask without using a SAM prompt encoder + mask decoder
36
+ # The maximum number of conditioning frames to participate in the memory attention (-1 means no limit; if there are more conditioning frames than this limit,
37
+ # we only cross-attend to the temporally closest `max_cond_frames_in_attn` conditioning frames in the encoder when tracking each frame). This gives the model
38
+ # a temporal locality when handling a large number of annotated frames (since closer frames should be more important) and also avoids GPU OOM.
39
+ max_cond_frames_in_attn=-1,
40
+ # on the first frame, whether to directly add the no-memory embedding to the image feature
41
+ # (instead of using the transformer encoder)
42
+ directly_add_no_mem_embed=False,
43
+ # whether to use high-resolution feature maps in the SAM mask decoder
44
+ use_high_res_features_in_sam=False,
45
+ # whether to output multiple (3) masks for the first click on initial conditioning frames
46
+ multimask_output_in_sam=False,
47
+ # the minimum and maximum number of clicks to use multimask_output_in_sam (only relevant when `multimask_output_in_sam=True`;
48
+ # default is 1 for both, meaning that only the first click gives multimask output; also note that a box counts as two points)
49
+ multimask_min_pt_num=1,
50
+ multimask_max_pt_num=1,
51
+ # whether to also use multimask output for tracking (not just for the first click on initial conditioning frames; only relevant when `multimask_output_in_sam=True`)
52
+ multimask_output_for_tracking=False,
53
+ # Whether to use multimask tokens for obj ptr; Only relevant when both
54
+ # use_obj_ptrs_in_encoder=True and multimask_output_for_tracking=True
55
+ use_multimask_token_for_obj_ptr: bool = False,
56
+ # whether to use sigmoid to restrict ious prediction to [0-1]
57
+ iou_prediction_use_sigmoid=False,
58
+ # The memory bank's temporal stride during evaluation (i.e. the `r` parameter in XMem and Cutie; XMem and Cutie use r=5).
59
+ # For r>1, the (self.num_maskmem - 1) non-conditioning memory frames consist of
60
+ # (self.num_maskmem - 2) nearest frames from every r-th frames, plus the last frame.
61
+ memory_temporal_stride_for_eval=1,
62
+ # if `add_all_frames_to_correct_as_cond` is True, we also append to the conditioning frame list any frame that receives a later correction click
63
+ # if `add_all_frames_to_correct_as_cond` is False, we restrict the conditioning frame list to only the initial conditioning frames
64
+ add_all_frames_to_correct_as_cond=False,
65
+ # whether to apply non-overlapping constraints on the object masks in the memory encoder during evaluation (to avoid/alleviate superposing masks)
66
+ non_overlap_masks_for_mem_enc=False,
67
+ # whether to cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
68
+ use_obj_ptrs_in_encoder=False,
69
+ # the maximum number of object pointers from other frames in encoder cross attention (only relevant when `use_obj_ptrs_in_encoder=True`)
70
+ max_obj_ptrs_in_encoder=16,
71
+ # whether to add temporal positional encoding to the object pointers in the encoder (only relevant when `use_obj_ptrs_in_encoder=True`)
72
+ add_tpos_enc_to_obj_ptrs=True,
73
+ # whether to add an extra linear projection layer for the temporal positional encoding in the object pointers to avoid potential interference
74
+ # with spatial positional encoding (only relevant when both `use_obj_ptrs_in_encoder=True` and `add_tpos_enc_to_obj_ptrs=True`)
75
+ proj_tpos_enc_in_obj_ptrs=False,
76
+ # whether to only attend to object pointers in the past (before the current frame) in the encoder during evaluation
77
+ # (only relevant when `use_obj_ptrs_in_encoder=True`; this might avoid pointer information too far in the future to distract the initial tracking)
78
+ only_obj_ptrs_in_the_past_for_eval=False,
79
+ # Whether to predict if there is an object in the frame
80
+ pred_obj_scores: bool = False,
81
+ # Whether to use an MLP to predict object scores
82
+ pred_obj_scores_mlp: bool = False,
83
+ # Only relevant if pred_obj_scores=True and use_obj_ptrs_in_encoder=True;
84
+ # Whether to have a fixed no obj pointer when there is no object present
85
+ # or to use it as an additive embedding with obj_ptr produced by decoder
86
+ fixed_no_obj_ptr: bool = False,
87
+ # Soft no object, i.e. mix in no_obj_ptr softly,
88
+ # hoping to make recovery easier if there is a mistake and to mitigate error accumulation
89
+ soft_no_obj_ptr: bool = False,
90
+ use_mlp_for_obj_ptr_proj: bool = False,
91
+ # extra arguments used to construct the SAM mask decoder; if not None, it should be a dict of kwargs to be passed into `MaskDecoder` class.
92
+ sam_mask_decoder_extra_args=None,
93
+ compile_image_encoder: bool = False,
94
+ ):
95
+ super().__init__()
96
+
97
+ # Part 1: the image backbone
98
+ self.image_encoder = image_encoder
99
+ # Use level 0, 1, 2 for high-res setting, or just level 2 for the default setting
100
+ self.use_high_res_features_in_sam = use_high_res_features_in_sam
101
+ self.num_feature_levels = 3 if use_high_res_features_in_sam else 1
102
+ self.use_obj_ptrs_in_encoder = use_obj_ptrs_in_encoder
103
+ self.max_obj_ptrs_in_encoder = max_obj_ptrs_in_encoder
104
+ if use_obj_ptrs_in_encoder:
105
+ # A conv layer to downsample the mask prompt to stride 4 (the same stride as
106
+ # low-res SAM mask logits) and to change its scale from 0~1 to the SAM logit scale,
107
+ # so that it can be fed into the SAM mask decoder to generate a pointer.
108
+ self.mask_downsample = torch.nn.Conv2d(1, 1, kernel_size=4, stride=4)
109
+ self.add_tpos_enc_to_obj_ptrs = add_tpos_enc_to_obj_ptrs
110
+ if proj_tpos_enc_in_obj_ptrs:
111
+ assert add_tpos_enc_to_obj_ptrs # these options need to be used together
112
+ self.proj_tpos_enc_in_obj_ptrs = proj_tpos_enc_in_obj_ptrs
113
+ self.only_obj_ptrs_in_the_past_for_eval = only_obj_ptrs_in_the_past_for_eval
114
+
115
+ # Part 2: memory attention to condition current frame's visual features
116
+ # with memories (and obj ptrs) from past frames
117
+ self.memory_attention = memory_attention
118
+ self.hidden_dim = memory_attention.d_model
119
+
120
+ # Part 3: memory encoder for the previous frame's outputs
121
+ self.memory_encoder = memory_encoder
122
+ self.mem_dim = self.hidden_dim
123
+ if hasattr(self.memory_encoder, "out_proj") and hasattr(
124
+ self.memory_encoder.out_proj, "weight"
125
+ ):
126
+ # if there is compression of memories along channel dim
127
+ self.mem_dim = self.memory_encoder.out_proj.weight.shape[0]
128
+ self.num_maskmem = num_maskmem # Number of memories accessible
129
+ # Temporal encoding of the memories
130
+ self.maskmem_tpos_enc = torch.nn.Parameter(
131
+ torch.zeros(num_maskmem, 1, 1, self.mem_dim)
132
+ )
133
+ trunc_normal_(self.maskmem_tpos_enc, std=0.02)
134
+ # a single token to indicate no memory embedding from previous frames
135
+ self.no_mem_embed = torch.nn.Parameter(torch.zeros(1, 1, self.hidden_dim))
136
+ self.no_mem_pos_enc = torch.nn.Parameter(torch.zeros(1, 1, self.hidden_dim))
137
+ trunc_normal_(self.no_mem_embed, std=0.02)
138
+ trunc_normal_(self.no_mem_pos_enc, std=0.02)
139
+ self.directly_add_no_mem_embed = directly_add_no_mem_embed
140
+ # Apply sigmoid to the output raw mask logits (to turn them from
141
+ # range (-inf, +inf) to range (0, 1)) before feeding them into the memory encoder
142
+ self.sigmoid_scale_for_mem_enc = sigmoid_scale_for_mem_enc
143
+ self.sigmoid_bias_for_mem_enc = sigmoid_bias_for_mem_enc
144
+ self.binarize_mask_from_pts_for_mem_enc = binarize_mask_from_pts_for_mem_enc
145
+ self.non_overlap_masks_for_mem_enc = non_overlap_masks_for_mem_enc
146
+ self.memory_temporal_stride_for_eval = memory_temporal_stride_for_eval
147
+ # On frames with mask input, whether to directly output the input mask without
148
+ # using a SAM prompt encoder + mask decoder
149
+ self.use_mask_input_as_output_without_sam = use_mask_input_as_output_without_sam
150
+ self.multimask_output_in_sam = multimask_output_in_sam
151
+ self.multimask_min_pt_num = multimask_min_pt_num
152
+ self.multimask_max_pt_num = multimask_max_pt_num
153
+ self.multimask_output_for_tracking = multimask_output_for_tracking
154
+ self.use_multimask_token_for_obj_ptr = use_multimask_token_for_obj_ptr
155
+ self.iou_prediction_use_sigmoid = iou_prediction_use_sigmoid
156
+
157
+ # Part 4: SAM-style prompt encoder (for both mask and point inputs)
158
+ # and SAM-style mask decoder for the final mask output
159
+ self.image_size = image_size
160
+ self.backbone_stride = backbone_stride
161
+ self.sam_mask_decoder_extra_args = sam_mask_decoder_extra_args
162
+ self.pred_obj_scores = pred_obj_scores
163
+ self.pred_obj_scores_mlp = pred_obj_scores_mlp
164
+ self.fixed_no_obj_ptr = fixed_no_obj_ptr
165
+ self.soft_no_obj_ptr = soft_no_obj_ptr
166
+ if self.fixed_no_obj_ptr:
167
+ assert self.pred_obj_scores
168
+ assert self.use_obj_ptrs_in_encoder
169
+ if self.pred_obj_scores and self.use_obj_ptrs_in_encoder:
170
+ self.no_obj_ptr = torch.nn.Parameter(torch.zeros(1, self.hidden_dim))
171
+ trunc_normal_(self.no_obj_ptr, std=0.02)
172
+ self.use_mlp_for_obj_ptr_proj = use_mlp_for_obj_ptr_proj
173
+
174
+ self._build_sam_heads()
175
+ self.add_all_frames_to_correct_as_cond = add_all_frames_to_correct_as_cond
176
+ self.max_cond_frames_in_attn = max_cond_frames_in_attn
177
+
178
+ # Model compilation
179
+ if compile_image_encoder:
180
+ # Compile the forward function (not the full module) to allow loading checkpoints.
181
+ print(
182
+ "Image encoder compilation is enabled. First forward pass will be slow."
183
+ )
184
+ self.image_encoder.forward = torch.compile(
185
+ self.image_encoder.forward,
186
+ mode="max-autotune",
187
+ fullgraph=True,
188
+ dynamic=False,
189
+ )
190
+
191
+ @property
192
+ def device(self):
193
+ return next(self.parameters()).device
194
+
195
+ def forward(self, *args, **kwargs):
196
+ raise NotImplementedError(
197
+ "Please use the corresponding methods in SAM2VideoPredictor for inference."
198
+ "See notebooks/video_predictor_example.ipynb for an example."
199
+ )
200
+
201
+ def _build_sam_heads(self):
202
+ """Build SAM-style prompt encoder and mask decoder."""
203
+ self.sam_prompt_embed_dim = self.hidden_dim
204
+ self.sam_image_embedding_size = self.image_size // self.backbone_stride
205
+
206
+ # build PromptEncoder and MaskDecoder from SAM
207
+ # (their hyperparameters like `mask_in_chans=16` are from SAM code)
208
+ self.sam_prompt_encoder = PromptEncoder(
209
+ embed_dim=self.sam_prompt_embed_dim,
210
+ image_embedding_size=(
211
+ self.sam_image_embedding_size,
212
+ self.sam_image_embedding_size,
213
+ ),
214
+ input_image_size=(self.image_size, self.image_size),
215
+ mask_in_chans=16,
216
+ )
217
+ self.sam_mask_decoder = MaskDecoder(
218
+ num_multimask_outputs=3,
219
+ transformer=TwoWayTransformer(
220
+ depth=2,
221
+ embedding_dim=self.sam_prompt_embed_dim,
222
+ mlp_dim=2048,
223
+ num_heads=8,
224
+ ),
225
+ transformer_dim=self.sam_prompt_embed_dim,
226
+ iou_head_depth=3,
227
+ iou_head_hidden_dim=256,
228
+ use_high_res_features=self.use_high_res_features_in_sam,
229
+ iou_prediction_use_sigmoid=self.iou_prediction_use_sigmoid,
230
+ pred_obj_scores=self.pred_obj_scores,
231
+ pred_obj_scores_mlp=self.pred_obj_scores_mlp,
232
+ use_multimask_token_for_obj_ptr=self.use_multimask_token_for_obj_ptr,
233
+ **(self.sam_mask_decoder_extra_args or {}),
234
+ )
235
+ if self.use_obj_ptrs_in_encoder:
236
+ # a linear projection on SAM output tokens to turn them into object pointers
237
+ self.obj_ptr_proj = torch.nn.Linear(self.hidden_dim, self.hidden_dim)
238
+ if self.use_mlp_for_obj_ptr_proj:
239
+ self.obj_ptr_proj = MLP(
240
+ self.hidden_dim, self.hidden_dim, self.hidden_dim, 3
241
+ )
242
+ else:
243
+ self.obj_ptr_proj = torch.nn.Identity()
244
+ if self.proj_tpos_enc_in_obj_ptrs:
245
+ # a linear projection on temporal positional encoding in object pointers to
246
+ # avoid potential interference with spatial positional encoding
247
+ self.obj_ptr_tpos_proj = torch.nn.Linear(self.hidden_dim, self.mem_dim)
248
+ else:
249
+ self.obj_ptr_tpos_proj = torch.nn.Identity()
250
+
251
+ def _forward_sam_heads(
252
+ self,
253
+ backbone_features,
254
+ point_inputs=None,
255
+ mask_inputs=None,
256
+ high_res_features=None,
257
+ multimask_output=False,
258
+ ):
259
+ """
260
+ Forward SAM prompt encoders and mask heads.
261
+
262
+ Inputs:
263
+ - backbone_features: image features of [B, C, H, W] shape
264
+ - point_inputs: a dictionary with "point_coords" and "point_labels", where
265
+ 1) "point_coords" has [B, P, 2] shape and float32 dtype and contains the
266
+ absolute pixel-unit coordinate in (x, y) format of the P input points
267
+ 2) "point_labels" has shape [B, P] and int32 dtype, where 1 means
268
+ positive clicks, 0 means negative clicks, and -1 means padding
269
+ - mask_inputs: a mask of [B, 1, H*16, W*16] shape, float or bool, with the
270
+ same spatial size as the image.
271
+ - high_res_features: either 1) None or 2) a list of length 2 containing
272
+ two feature maps of [B, C, 4*H, 4*W] and [B, C, 2*H, 2*W] shapes respectively,
273
+ which will be used as high-resolution feature maps for SAM decoder.
274
+ - multimask_output: if it's True, we output 3 candidate masks and their 3
275
+ corresponding IoU estimates, and if it's False, we output only 1 mask and
276
+ its corresponding IoU estimate.
277
+
278
+ Outputs:
279
+ - low_res_multimasks: [B, M, H*4, W*4] shape (where M = 3 if
280
+ `multimask_output=True` and M = 1 if `multimask_output=False`), the SAM
281
+ output mask logits (before sigmoid) for the low-resolution masks, with 4x
282
+ the resolution (1/4 stride) of the input backbone_features.
283
+ - high_res_multimasks: [B, M, H*16, W*16] shape (where M = 3
284
+ if `multimask_output=True` and M = 1 if `multimask_output=False`),
285
+ upsampled from the low-resolution masks, with the same size as the image
286
+ (stride is 1 pixel).
287
+ - ious: [B, M] shape (where M = 3 if `multimask_output=True` and M = 1
288
+ if `multimask_output=False`), the estimated IoU of each output mask.
289
+ - low_res_masks: [B, 1, H*4, W*4] shape, the best mask in `low_res_multimasks`.
290
+ If `multimask_output=True`, it's the mask with the highest IoU estimate.
291
+ If `multimask_output=False`, it's the same as `low_res_multimasks`.
292
+ - high_res_masks: [B, 1, H*16, W*16] shape, the best mask in `high_res_multimasks`.
293
+ If `multimask_output=True`, it's the mask with the highest IoU estimate.
294
+ If `multimask_output=False`, it's the same as `high_res_multimasks`.
295
+ - obj_ptr: [B, C] shape, the object pointer vector for the output mask, extracted
296
+ based on the output token from the SAM mask decoder.
297
+ """
298
+ B = backbone_features.size(0)
299
+ device = backbone_features.device
300
+ assert backbone_features.size(1) == self.sam_prompt_embed_dim
301
+ assert backbone_features.size(2) == self.sam_image_embedding_size
302
+ assert backbone_features.size(3) == self.sam_image_embedding_size
303
+
304
+ # a) Handle point prompts
305
+ if point_inputs is not None:
306
+ sam_point_coords = point_inputs["point_coords"]
307
+ sam_point_labels = point_inputs["point_labels"]
308
+ assert sam_point_coords.size(0) == B and sam_point_labels.size(0) == B
309
+ else:
310
+ # If no points are provided, pad with an empty point (with label -1)
311
+ sam_point_coords = torch.zeros(B, 1, 2, device=device)
312
+ sam_point_labels = -torch.ones(B, 1, dtype=torch.int32, device=device)
313
+
314
+ # b) Handle mask prompts
315
+ if mask_inputs is not None:
316
+ # If mask_inputs is provided, downsize it into low-res mask input if needed
317
+ # and feed it as a dense mask prompt into the SAM mask encoder
318
+ assert len(mask_inputs.shape) == 4 and mask_inputs.shape[:2] == (B, 1)
319
+ if mask_inputs.shape[-2:] != self.sam_prompt_encoder.mask_input_size:
320
+ sam_mask_prompt = F.interpolate(
321
+ mask_inputs.float(),
322
+ size=self.sam_prompt_encoder.mask_input_size,
323
+ align_corners=False,
324
+ mode="bilinear",
325
+ antialias=True, # use antialias for downsampling
326
+ )
327
+ else:
328
+ sam_mask_prompt = mask_inputs
329
+ else:
330
+ # Otherwise, simply feed None (and SAM's prompt encoder will add
331
+ # a learned `no_mask_embed` to indicate no mask input in this case).
332
+ sam_mask_prompt = None
333
+
334
+ sparse_embeddings, dense_embeddings = self.sam_prompt_encoder(
335
+ points=(sam_point_coords, sam_point_labels),
336
+ boxes=None,
337
+ masks=sam_mask_prompt,
338
+ )
339
+ (
340
+ low_res_multimasks,
341
+ ious,
342
+ sam_output_tokens,
343
+ object_score_logits,
344
+ ) = self.sam_mask_decoder(
345
+ image_embeddings=backbone_features,
346
+ image_pe=self.sam_prompt_encoder.get_dense_pe(),
347
+ sparse_prompt_embeddings=sparse_embeddings,
348
+ dense_prompt_embeddings=dense_embeddings,
349
+ multimask_output=multimask_output,
350
+ repeat_image=False, # the image is already batched
351
+ high_res_features=high_res_features,
352
+ )
353
+ if self.pred_obj_scores:
354
+ is_obj_appearing = object_score_logits > 0
355
+
356
+ # Mask used for spatial memories is always a *hard* choice between obj and no obj,
357
+ # consistent with the actual mask prediction
358
+ low_res_multimasks = torch.where(
359
+ is_obj_appearing[:, None, None],
360
+ low_res_multimasks,
361
+ NO_OBJ_SCORE,
362
+ )
363
+
364
+ # convert masks from possibly bfloat16 (or float16) to float32
365
+ # (older PyTorch versions before 2.1 don't support `interpolate` on bf16)
366
+ _dtype = low_res_multimasks.dtype
367
+ # low_res_multimasks = low_res_multimasks.float()
368
+ high_res_multimasks = F.interpolate(
369
+ low_res_multimasks.float(),
370
+ size=(self.image_size, self.image_size),
371
+ mode="bilinear",
372
+ align_corners=False,
373
+ ).to(_dtype)
374
+
375
+ sam_output_token = sam_output_tokens[:, 0]
376
+ if multimask_output:
377
+ # take the best mask prediction (with the highest IoU estimation)
378
+ best_iou_inds = torch.argmax(ious, dim=-1)
379
+ batch_inds = torch.arange(B, device=device)
380
+ low_res_masks = low_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1)
381
+ high_res_masks = high_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1)
382
+ if sam_output_tokens.size(1) > 1:
383
+ sam_output_token = sam_output_tokens[batch_inds, best_iou_inds]
384
+ else:
385
+ low_res_masks, high_res_masks = low_res_multimasks, high_res_multimasks
386
+
387
+ # Extract object pointer from the SAM output token (with occlusion handling)
388
+ obj_ptr = self.obj_ptr_proj(sam_output_token)
389
+ if self.pred_obj_scores:
390
+ # Allow *soft* no obj ptr, unlike for masks
391
+ if self.soft_no_obj_ptr:
392
+ # Only hard possible with gt
393
+ assert not self.teacher_force_obj_scores_for_mem
394
+ lambda_is_obj_appearing = object_score_logits.sigmoid()
395
+ else:
396
+ lambda_is_obj_appearing = is_obj_appearing.float()
397
+
398
+ if self.fixed_no_obj_ptr:
399
+ obj_ptr = lambda_is_obj_appearing * obj_ptr
400
+ obj_ptr = obj_ptr + (1 - lambda_is_obj_appearing) * self.no_obj_ptr
401
+
402
+ return (
403
+ low_res_multimasks,
404
+ high_res_multimasks,
405
+ ious,
406
+ low_res_masks,
407
+ high_res_masks,
408
+ obj_ptr,
409
+ object_score_logits,
410
+ )
411
+
412
+ def _use_mask_as_output(self, backbone_features, high_res_features, mask_inputs):
413
+ """
414
+ Directly turn binary `mask_inputs` into output mask logits without using SAM.
415
+ (same input and output shapes as in _forward_sam_heads above).
416
+ """
417
+ # Use -10/+10 as logits for neg/pos pixels (very close to 0/1 in prob after sigmoid).
418
+ out_scale, out_bias = 20.0, -10.0 # sigmoid(-10.0)=4.5398e-05
419
+ mask_inputs_float = mask_inputs.float()
420
+ high_res_masks = mask_inputs_float * out_scale + out_bias
421
+ low_res_masks = F.interpolate(
422
+ high_res_masks,
423
+ size=(high_res_masks.size(-2) // 4, high_res_masks.size(-1) // 4),
424
+ align_corners=False,
425
+ mode="bilinear",
426
+ antialias=True, # use antialias for downsampling
427
+ )
428
+ # a dummy IoU prediction of all 1's under mask input
429
+ ious = mask_inputs.new_ones(mask_inputs.size(0), 1).float()
430
+ if not self.use_obj_ptrs_in_encoder:
431
+ # all zeros as a dummy object pointer (of shape [B, C])
432
+ obj_ptr = torch.zeros(
433
+ mask_inputs.size(0), self.hidden_dim, device=mask_inputs.device
434
+ )
435
+ else:
436
+ # produce an object pointer using the SAM decoder from the mask input
437
+ _, _, _, _, _, obj_ptr, _ = self._forward_sam_heads(
438
+ backbone_features=backbone_features,
439
+ mask_inputs=self.mask_downsample(mask_inputs_float),
440
+ high_res_features=high_res_features,
441
+ )
442
+ # In this method, we are treating mask_input as output, e.g. using it directly to create spatial mem;
443
+ # Below, we follow the same design axiom to use mask_input to decide if obj appears or not instead of relying
444
+ # on the object_scores from the SAM decoder.
445
+ is_obj_appearing = torch.any(mask_inputs.flatten(1).float() > 0.0, dim=1)
446
+ is_obj_appearing = is_obj_appearing[..., None]
447
+ lambda_is_obj_appearing = is_obj_appearing.float()
448
+ object_score_logits = out_scale * lambda_is_obj_appearing + out_bias
449
+ if self.pred_obj_scores:
450
+ if self.fixed_no_obj_ptr:
451
+ obj_ptr = lambda_is_obj_appearing * obj_ptr
452
+ obj_ptr = obj_ptr + (1 - lambda_is_obj_appearing) * self.no_obj_ptr
453
+
454
+ return (
455
+ low_res_masks,
456
+ high_res_masks,
457
+ ious,
458
+ low_res_masks,
459
+ high_res_masks,
460
+ obj_ptr,
461
+ object_score_logits,
462
+ )
463
+
464
+ def forward_image(self, img_batch: torch.Tensor):
465
+ """Get the image feature on the input batch."""
466
+ backbone_out = self.image_encoder(img_batch)
467
+ if self.use_high_res_features_in_sam:
468
+ # precompute projected level 0 and level 1 features in SAM decoder
469
+ # to avoid running it again on every SAM click
470
+ backbone_out["backbone_fpn"][0] = self.sam_mask_decoder.conv_s0(
471
+ backbone_out["backbone_fpn"][0]
472
+ )
473
+ backbone_out["backbone_fpn"][1] = self.sam_mask_decoder.conv_s1(
474
+ backbone_out["backbone_fpn"][1]
475
+ )
476
+ return backbone_out
477
+
478
+ def _prepare_backbone_features(self, backbone_out):
479
+ """Prepare and flatten visual features."""
480
+ backbone_out = backbone_out.copy()
481
+ assert len(backbone_out["backbone_fpn"]) == len(backbone_out["vision_pos_enc"])
482
+ assert len(backbone_out["backbone_fpn"]) >= self.num_feature_levels
483
+
484
+ feature_maps = backbone_out["backbone_fpn"][-self.num_feature_levels :]
485
+ vision_pos_embeds = backbone_out["vision_pos_enc"][-self.num_feature_levels :]
486
+
487
+ feat_sizes = [(x.shape[-2], x.shape[-1]) for x in vision_pos_embeds]
488
+ # flatten NxCxHxW to HWxNxC
489
+ vision_feats = [x.flatten(2).permute(2, 0, 1) for x in feature_maps]
490
+ vision_pos_embeds = [x.flatten(2).permute(2, 0, 1) for x in vision_pos_embeds]
491
+
492
+ return backbone_out, vision_feats, vision_pos_embeds, feat_sizes
493
+
494
+ def _prepare_memory_conditioned_features(
495
+ self,
496
+ frame_idx,
497
+ is_init_cond_frame,
498
+ current_vision_feats,
499
+ current_vision_pos_embeds,
500
+ feat_sizes,
501
+ output_dict,
502
+ num_frames,
503
+ track_in_reverse=False, # tracking in reverse time order (for demo usage)
504
+ ):
505
+ """Fuse the current frame's visual feature map with previous memory."""
506
+ B = current_vision_feats[-1].size(1) # batch size on this frame
507
+ C = self.hidden_dim
508
+ H, W = feat_sizes[-1] # top-level (lowest-resolution) feature size
509
+ device = current_vision_feats[-1].device
510
+ # The case of `self.num_maskmem == 0` below is primarily used for reproducing SAM on images.
511
+ # In this case, we skip the fusion with any memory.
512
+ if self.num_maskmem == 0: # Disable memory and skip fusion
513
+ pix_feat = current_vision_feats[-1].permute(1, 2, 0).view(B, C, H, W)
514
+ return pix_feat
515
+
516
+ num_obj_ptr_tokens = 0
517
+ # Step 1: condition the visual features of the current frame on previous memories
518
+ if not is_init_cond_frame:
519
+ # Retrieve the memories encoded with the maskmem backbone
520
+ to_cat_memory, to_cat_memory_pos_embed = [], []
521
+ # Add conditioning frames' outputs first (all cond frames have t_pos=0 for
522
+ # when getting temporal positional embedding below)
523
+ assert len(output_dict["cond_frame_outputs"]) > 0
524
+ # Select a maximum number of temporally closest cond frames for cross attention
525
+ cond_outputs = output_dict["cond_frame_outputs"]
526
+ selected_cond_outputs, unselected_cond_outputs = select_closest_cond_frames(
527
+ frame_idx, cond_outputs, self.max_cond_frames_in_attn
528
+ )
529
+ t_pos_and_prevs = [(0, out) for out in selected_cond_outputs.values()]
530
+ # Add last (self.num_maskmem - 1) frames before current frame for non-conditioning memory
531
+ # the earliest one has t_pos=1 and the latest one has t_pos=self.num_maskmem-1
532
+ # We also allow taking the memory frame non-consecutively (with r>1), in which case
533
+ # we take (self.num_maskmem - 2) frames among every r-th frames plus the last frame.
534
+ r = self.memory_temporal_stride_for_eval
535
+ for t_pos in range(1, self.num_maskmem):
536
+ t_rel = self.num_maskmem - t_pos # how many frames before current frame
537
+ if t_rel == 1:
538
+ # for t_rel == 1, we take the last frame (regardless of r)
539
+ if not track_in_reverse:
540
+ # the frame immediately before this frame (i.e. frame_idx - 1)
541
+ prev_frame_idx = frame_idx - t_rel
542
+ else:
543
+ # the frame immediately after this frame (i.e. frame_idx + 1)
544
+ prev_frame_idx = frame_idx + t_rel
545
+ else:
546
+ # for t_rel >= 2, we take the memory frame from every r-th frames
547
+ if not track_in_reverse:
548
+ # first find the nearest frame among every r-th frames before this frame
549
+ # for r=1, this would be (frame_idx - 2)
550
+ prev_frame_idx = ((frame_idx - 2) // r) * r
551
+ # then seek further among every r-th frames
552
+ prev_frame_idx = prev_frame_idx - (t_rel - 2) * r
553
+ else:
554
+ # first find the nearest frame among every r-th frames after this frame
555
+ # for r=1, this would be (frame_idx + 2)
556
+ prev_frame_idx = -(-(frame_idx + 2) // r) * r
557
+ # then seek further among every r-th frames
558
+ prev_frame_idx = prev_frame_idx + (t_rel - 2) * r
559
+ out = output_dict["non_cond_frame_outputs"].get(prev_frame_idx, None)
560
+ if out is None:
561
+ # If an unselected conditioning frame is among the last (self.num_maskmem - 1)
562
+ # frames, we still attend to it as if it's a non-conditioning frame.
563
+ out = unselected_cond_outputs.get(prev_frame_idx, None)
564
+ t_pos_and_prevs.append((t_pos, out))
565
+
566
+ for t_pos, prev in t_pos_and_prevs:
567
+ if prev is None:
568
+ continue # skip padding frames
569
+ # "maskmem_features" might have been offloaded to CPU in demo use cases,
570
+ # so we load it back to GPU (it's a no-op if it's already on GPU).
571
+ feats = prev["maskmem_features"].cuda(non_blocking=True)
572
+ to_cat_memory.append(feats.flatten(2).permute(2, 0, 1))
573
+ # Spatial positional encoding (it might have been offloaded to CPU in eval)
574
+ maskmem_enc = prev["maskmem_pos_enc"][-1].cuda()
575
+ maskmem_enc = maskmem_enc.flatten(2).permute(2, 0, 1)
576
+ # Temporal positional encoding
577
+ maskmem_enc = (
578
+ maskmem_enc + self.maskmem_tpos_enc[self.num_maskmem - t_pos - 1]
579
+ )
580
+ to_cat_memory_pos_embed.append(maskmem_enc)
581
+
582
+ # Construct the list of past object pointers
583
+ if self.use_obj_ptrs_in_encoder:
584
+ max_obj_ptrs_in_encoder = min(num_frames, self.max_obj_ptrs_in_encoder)
585
+ # First add those object pointers from selected conditioning frames
586
+ # (optionally, only include object pointers in the past during evaluation)
587
+ if not self.training and self.only_obj_ptrs_in_the_past_for_eval:
588
+ ptr_cond_outputs = {
589
+ t: out
590
+ for t, out in selected_cond_outputs.items()
591
+ if (t >= frame_idx if track_in_reverse else t <= frame_idx)
592
+ }
593
+ else:
594
+ ptr_cond_outputs = selected_cond_outputs
595
+ pos_and_ptrs = [
596
+ # Temporal pos encoding contains how far away each pointer is from current frame
597
+ (abs(frame_idx - t), out["obj_ptr"])
598
+ for t, out in ptr_cond_outputs.items()
599
+ ]
600
+ # Add up to (max_obj_ptrs_in_encoder - 1) non-conditioning frames before current frame
601
+ for t_diff in range(1, max_obj_ptrs_in_encoder):
602
+ t = frame_idx + t_diff if track_in_reverse else frame_idx - t_diff
603
+ if t < 0 or (num_frames is not None and t >= num_frames):
604
+ break
605
+ out = output_dict["non_cond_frame_outputs"].get(
606
+ t, unselected_cond_outputs.get(t, None)
607
+ )
608
+ if out is not None:
609
+ pos_and_ptrs.append((t_diff, out["obj_ptr"]))
610
+ # If we have at least one object pointer, add them to the cross attention
611
+ if len(pos_and_ptrs) > 0:
612
+ pos_list, ptrs_list = zip(*pos_and_ptrs)
613
+ # stack object pointers along dim=0 into [ptr_seq_len, B, C] shape
614
+ obj_ptrs = torch.stack(ptrs_list, dim=0)
615
+ # a temporal positional embedding based on how far each object pointer is from
616
+ # the current frame (sine embedding normalized by the max pointer num).
617
+ if self.add_tpos_enc_to_obj_ptrs:
618
+ t_diff_max = max_obj_ptrs_in_encoder - 1
619
+ tpos_dim = C if self.proj_tpos_enc_in_obj_ptrs else self.mem_dim
620
+ obj_pos = torch.tensor(pos_list, device=device)
621
+ obj_pos = get_1d_sine_pe(obj_pos / t_diff_max, dim=tpos_dim)
622
+ obj_pos = self.obj_ptr_tpos_proj(obj_pos)
623
+ obj_pos = obj_pos.unsqueeze(1).expand(-1, B, self.mem_dim)
624
+ else:
625
+ obj_pos = obj_ptrs.new_zeros(len(pos_list), B, self.mem_dim)
626
+ if self.mem_dim < C:
627
+ # split a pointer into (C // self.mem_dim) tokens for self.mem_dim < C
628
+ obj_ptrs = obj_ptrs.reshape(
629
+ -1, B, C // self.mem_dim, self.mem_dim
630
+ )
631
+ obj_ptrs = obj_ptrs.permute(0, 2, 1, 3).flatten(0, 1)
632
+ obj_pos = obj_pos.repeat_interleave(C // self.mem_dim, dim=0)
633
+ to_cat_memory.append(obj_ptrs)
634
+ to_cat_memory_pos_embed.append(obj_pos)
635
+ num_obj_ptr_tokens = obj_ptrs.shape[0]
636
+ else:
637
+ num_obj_ptr_tokens = 0
638
+ else:
639
+ # for initial conditioning frames, encode them without using any previous memory
640
+ if self.directly_add_no_mem_embed:
641
+ # directly add no-mem embedding (instead of using the transformer encoder)
642
+ pix_feat_with_mem = current_vision_feats[-1] + self.no_mem_embed
643
+ pix_feat_with_mem = pix_feat_with_mem.permute(1, 2, 0).view(B, C, H, W)
644
+ return pix_feat_with_mem
645
+
646
+ # Use a dummy token on the first frame (to avoid empty memory input to the transformer encoder)
647
+ to_cat_memory = [self.no_mem_embed.expand(1, B, self.mem_dim)]
648
+ to_cat_memory_pos_embed = [self.no_mem_pos_enc.expand(1, B, self.mem_dim)]
649
+
650
+ # Step 2: Concatenate the memories and forward through the transformer encoder
651
+ memory = torch.cat(to_cat_memory, dim=0)
652
+ memory_pos_embed = torch.cat(to_cat_memory_pos_embed, dim=0)
653
+
654
+ pix_feat_with_mem = self.memory_attention(
655
+ curr=current_vision_feats,
656
+ curr_pos=current_vision_pos_embeds,
657
+ memory=memory,
658
+ memory_pos=memory_pos_embed,
659
+ num_obj_ptr_tokens=num_obj_ptr_tokens,
660
+ )
661
+ # reshape the output (HW)BC => BCHW
662
+ pix_feat_with_mem = pix_feat_with_mem.permute(1, 2, 0).view(B, C, H, W)
663
+ return pix_feat_with_mem
664
+
665
+ def _encode_new_memory(
666
+ self,
667
+ current_vision_feats,
668
+ feat_sizes,
669
+ pred_masks_high_res,
670
+ is_mask_from_pts,
671
+ ):
672
+ """Encode the current image and its prediction into a memory feature."""
673
+ B = current_vision_feats[-1].size(1) # batch size on this frame
674
+ C = self.hidden_dim
675
+ H, W = feat_sizes[-1] # top-level (lowest-resolution) feature size
676
+ # top-level feature, (HW)BC => BCHW
677
+ pix_feat = current_vision_feats[-1].permute(1, 2, 0).view(B, C, H, W)
678
+ if self.non_overlap_masks_for_mem_enc and not self.training:
679
+ # optionally, apply non-overlapping constraints to the masks (it's applied
680
+ # in the batch dimension and should only be used during eval, where all
681
+ # the objects come from the same video under batch size 1).
682
+ pred_masks_high_res = self._apply_non_overlapping_constraints(
683
+ pred_masks_high_res
684
+ )
685
+ # scale the raw mask logits with a temperature before applying sigmoid
686
+ binarize = self.binarize_mask_from_pts_for_mem_enc and is_mask_from_pts
687
+ if binarize and not self.training:
688
+ mask_for_mem = (pred_masks_high_res > 0).float()
689
+ else:
690
+ # apply sigmoid on the raw mask logits to turn them into range (0, 1)
691
+ mask_for_mem = torch.sigmoid(pred_masks_high_res)
692
+ # apply scale and bias terms to the sigmoid probabilities
693
+ if self.sigmoid_scale_for_mem_enc != 1.0:
694
+ mask_for_mem = mask_for_mem * self.sigmoid_scale_for_mem_enc
695
+ if self.sigmoid_bias_for_mem_enc != 0.0:
696
+ mask_for_mem = mask_for_mem + self.sigmoid_bias_for_mem_enc
697
+ maskmem_out = self.memory_encoder(
698
+ pix_feat, mask_for_mem, skip_mask_sigmoid=True # sigmoid already applied
699
+ )
700
+ maskmem_features = maskmem_out["vision_features"]
701
+ maskmem_pos_enc = maskmem_out["vision_pos_enc"]
702
+
703
+ return maskmem_features, maskmem_pos_enc
704
+
705
+ def track_step(
706
+ self,
707
+ frame_idx,
708
+ is_init_cond_frame,
709
+ current_vision_feats,
710
+ current_vision_pos_embeds,
711
+ feat_sizes,
712
+ point_inputs,
713
+ mask_inputs,
714
+ output_dict,
715
+ num_frames,
716
+ track_in_reverse=False, # tracking in reverse time order (for demo usage)
717
+ # Whether to run the memory encoder on the predicted masks. Sometimes we might want
718
+ # to skip the memory encoder with `run_mem_encoder=False`. For example,
719
+ # in demo we might call `track_step` multiple times for each user click,
720
+ # and only encode the memory when the user finalizes their clicks. And in ablation
721
+ # settings like SAM training on static images, we don't need the memory encoder.
722
+ run_mem_encoder=True,
723
+ # The previously predicted SAM mask logits (which can be fed together with new clicks in demo).
724
+ prev_sam_mask_logits=None,
725
+ ):
726
+ current_out = {"point_inputs": point_inputs, "mask_inputs": mask_inputs}
727
+ # High-resolution feature maps for the SAM head, reshape (HW)BC => BCHW
728
+ if len(current_vision_feats) > 1:
729
+ high_res_features = [
730
+ x.permute(1, 2, 0).view(x.size(1), x.size(2), *s)
731
+ for x, s in zip(current_vision_feats[:-1], feat_sizes[:-1])
732
+ ]
733
+ else:
734
+ high_res_features = None
735
+ if mask_inputs is not None and self.use_mask_input_as_output_without_sam:
736
+ # When use_mask_input_as_output_without_sam=True, we directly output the mask input
737
+ # (see it as a GT mask) without using a SAM prompt encoder + mask decoder.
738
+ pix_feat = current_vision_feats[-1].permute(1, 2, 0)
739
+ pix_feat = pix_feat.view(-1, self.hidden_dim, *feat_sizes[-1])
740
+ sam_outputs = self._use_mask_as_output(
741
+ pix_feat, high_res_features, mask_inputs
742
+ )
743
+ else:
744
+ # fuse the visual feature with previous memory features in the memory bank
745
+ pix_feat_with_mem = self._prepare_memory_conditioned_features(
746
+ frame_idx=frame_idx,
747
+ is_init_cond_frame=is_init_cond_frame,
748
+ current_vision_feats=current_vision_feats[-1:],
749
+ current_vision_pos_embeds=current_vision_pos_embeds[-1:],
750
+ feat_sizes=feat_sizes[-1:],
751
+ output_dict=output_dict,
752
+ num_frames=num_frames,
753
+ track_in_reverse=track_in_reverse,
754
+ )
755
+ # apply SAM-style segmentation head
756
+ # here we might feed previously predicted low-res SAM mask logits into the SAM mask decoder,
757
+ # e.g. in demo where such logits come from earlier interaction instead of correction sampling
758
+ # (in this case, any `mask_inputs` shouldn't reach here as they are sent to _use_mask_as_output instead)
759
+ if prev_sam_mask_logits is not None:
760
+ assert point_inputs is not None and mask_inputs is None
761
+ mask_inputs = prev_sam_mask_logits
762
+ multimask_output = self._use_multimask(is_init_cond_frame, point_inputs)
763
+ sam_outputs = self._forward_sam_heads(
764
+ backbone_features=pix_feat_with_mem,
765
+ point_inputs=point_inputs,
766
+ mask_inputs=mask_inputs,
767
+ high_res_features=high_res_features,
768
+ multimask_output=multimask_output,
769
+ )
770
+ (
771
+ _,
772
+ _,
773
+ _,
774
+ low_res_masks,
775
+ high_res_masks,
776
+ obj_ptr,
777
+ _,
778
+ ) = sam_outputs
779
+
780
+ current_out["pred_masks"] = low_res_masks
781
+ current_out["pred_masks_high_res"] = high_res_masks
782
+ current_out["obj_ptr"] = obj_ptr
783
+
784
+ # Finally run the memory encoder on the predicted mask to encode
785
+ # it into a new memory feature (that can be used in future frames)
786
+ if run_mem_encoder and self.num_maskmem > 0:
787
+ high_res_masks_for_mem_enc = high_res_masks
788
+ maskmem_features, maskmem_pos_enc = self._encode_new_memory(
789
+ current_vision_feats=current_vision_feats,
790
+ feat_sizes=feat_sizes,
791
+ pred_masks_high_res=high_res_masks_for_mem_enc,
792
+ is_mask_from_pts=(point_inputs is not None),
793
+ )
794
+ current_out["maskmem_features"] = maskmem_features
795
+ current_out["maskmem_pos_enc"] = maskmem_pos_enc
796
+ else:
797
+ current_out["maskmem_features"] = None
798
+ current_out["maskmem_pos_enc"] = None
799
+
800
+ return current_out
801
+
802
+ def _use_multimask(self, is_init_cond_frame, point_inputs):
803
+ """Whether to use multimask output in the SAM head."""
804
+ num_pts = 0 if point_inputs is None else point_inputs["point_labels"].size(1)
805
+ multimask_output = (
806
+ self.multimask_output_in_sam
807
+ and (is_init_cond_frame or self.multimask_output_for_tracking)
808
+ and (self.multimask_min_pt_num <= num_pts <= self.multimask_max_pt_num)
809
+ )
810
+ return multimask_output
811
+
812
+ def _apply_non_overlapping_constraints(self, pred_masks):
813
+ """
814
+ Apply non-overlapping constraints to the object scores in pred_masks. Here we
815
+ keep only the highest scoring object at each spatial location in pred_masks.
816
+ """
817
+ batch_size = pred_masks.size(0)
818
+ if batch_size == 1:
819
+ return pred_masks
820
+
821
+ device = pred_masks.device
822
+ # "max_obj_inds": object index of the object with the highest score at each location
823
+ max_obj_inds = torch.argmax(pred_masks, dim=0, keepdim=True)
824
+ # "batch_obj_inds": object index of each object slice (along dim 0) in `pred_masks`
825
+ batch_obj_inds = torch.arange(batch_size, device=device)[:, None, None, None]
826
+ keep = max_obj_inds == batch_obj_inds
827
+ # suppress overlapping regions' scores below -10.0 so that the foreground regions
828
+ # don't overlap (here sigmoid(-10.0)=4.5398e-05)
829
+ pred_masks = torch.where(keep, pred_masks, torch.clamp(pred_masks, max=-10.0))
830
+ return pred_masks
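
The non-conditioning memory selection in `_prepare_memory_conditioned_features` above (the `prev_frame_idx` arithmetic) is easy to misread, so here is a small standalone sketch, assuming forward tracking (`track_in_reverse=False`) and mirroring the `frame_idx` / `num_maskmem` / `r` names used in that method; it simply reports which earlier frames end up in the memory bank.

# Sketch of the forward-tracking branch of the memory-frame selection above.
def selected_memory_frame_indices(frame_idx, num_maskmem=7, r=1):
    indices = []
    for t_pos in range(1, num_maskmem):
        t_rel = num_maskmem - t_pos  # how many frames before the current frame
        if t_rel == 1:
            # the immediately previous frame, regardless of the stride r
            prev_frame_idx = frame_idx - 1
        else:
            # nearest multiple of r at or before (frame_idx - 2), then step back in r-sized hops
            prev_frame_idx = ((frame_idx - 2) // r) * r - (t_rel - 2) * r
        indices.append(prev_frame_idx)
    return indices

print(selected_memory_frame_indices(20, num_maskmem=7, r=1))  # [14, 15, 16, 17, 18, 19]
print(selected_memory_frame_indices(20, num_maskmem=7, r=2))  # [10, 12, 14, 16, 18, 19]
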
third_parts/sam2/modeling/sam2_utils.py ADDED
@@ -0,0 +1,149 @@
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+
8
+ import copy
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+
14
+
15
+ def select_closest_cond_frames(frame_idx, cond_frame_outputs, max_cond_frame_num):
16
+ """
17
+ Select up to `max_cond_frame_num` conditioning frames from `cond_frame_outputs`
18
+ that are temporally closest to the current frame at `frame_idx`. Here, we take
19
+ - a) the closest conditioning frame before `frame_idx` (if any);
20
+ - b) the closest conditioning frame after `frame_idx` (if any);
21
+ - c) any other temporally closest conditioning frames until reaching a total
22
+ of `max_cond_frame_num` conditioning frames.
23
+
24
+ Outputs:
25
+ - selected_outputs: selected items (keys & values) from `cond_frame_outputs`.
26
+ - unselected_outputs: items (keys & values) not selected in `cond_frame_outputs`.
27
+ """
28
+ if max_cond_frame_num == -1 or len(cond_frame_outputs) <= max_cond_frame_num:
29
+ selected_outputs = cond_frame_outputs
30
+ unselected_outputs = {}
31
+ else:
32
+ assert max_cond_frame_num >= 2, "we should allow using 2+ conditioning frames"
33
+ selected_outputs = {}
34
+
35
+ # the closest conditioning frame before `frame_idx` (if any)
36
+ idx_before = max((t for t in cond_frame_outputs if t < frame_idx), default=None)
37
+ if idx_before is not None:
38
+ selected_outputs[idx_before] = cond_frame_outputs[idx_before]
39
+
40
+ # the closest conditioning frame after `frame_idx` (if any)
41
+ idx_after = min((t for t in cond_frame_outputs if t >= frame_idx), default=None)
42
+ if idx_after is not None:
43
+ selected_outputs[idx_after] = cond_frame_outputs[idx_after]
44
+
45
+ # add other temporally closest conditioning frames until reaching a total
46
+ # of `max_cond_frame_num` conditioning frames.
47
+ num_remain = max_cond_frame_num - len(selected_outputs)
48
+ inds_remain = sorted(
49
+ (t for t in cond_frame_outputs if t not in selected_outputs),
50
+ key=lambda x: abs(x - frame_idx),
51
+ )[:num_remain]
52
+ selected_outputs.update((t, cond_frame_outputs[t]) for t in inds_remain)
53
+ unselected_outputs = {
54
+ t: v for t, v in cond_frame_outputs.items() if t not in selected_outputs
55
+ }
56
+
57
+ return selected_outputs, unselected_outputs
58
+
59
+
60
+ def get_1d_sine_pe(pos_inds, dim, temperature=10000):
61
+ """
62
+ Get 1D sine positional embedding as in the original Transformer paper.
63
+ """
64
+ pe_dim = dim // 2
65
+ dim_t = torch.arange(pe_dim, dtype=torch.float32, device=pos_inds.device)
66
+ dim_t = temperature ** (2 * (dim_t // 2) / pe_dim)
67
+
68
+ pos_embed = pos_inds.unsqueeze(-1) / dim_t
69
+ pos_embed = torch.cat([pos_embed.sin(), pos_embed.cos()], dim=-1)
70
+ return pos_embed
71
+
72
+
73
+ def get_activation_fn(activation):
74
+ """Return an activation function given a string"""
75
+ if activation == "relu":
76
+ return F.relu
77
+ if activation == "gelu":
78
+ return F.gelu
79
+ if activation == "glu":
80
+ return F.glu
81
+ raise RuntimeError(f"activation should be relu/gelu, not {activation}.")
82
+
83
+
84
+ def get_clones(module, N):
85
+ return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
86
+
87
+
88
+ class DropPath(nn.Module):
89
+ # adapted from https://github.com/huggingface/pytorch-image-models/blob/main/timm/layers/drop.py
90
+ def __init__(self, drop_prob=0.0, scale_by_keep=True):
91
+ super(DropPath, self).__init__()
92
+ self.drop_prob = drop_prob
93
+ self.scale_by_keep = scale_by_keep
94
+
95
+ def forward(self, x):
96
+ if self.drop_prob == 0.0 or not self.training:
97
+ return x
98
+ keep_prob = 1 - self.drop_prob
99
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1)
100
+ random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
101
+ if keep_prob > 0.0 and self.scale_by_keep:
102
+ random_tensor.div_(keep_prob)
103
+ return x * random_tensor
104
+
105
+
106
+ # Lightly adapted from
107
+ # https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa
108
+ class MLP(nn.Module):
109
+ def __init__(
110
+ self,
111
+ input_dim: int,
112
+ hidden_dim: int,
113
+ output_dim: int,
114
+ num_layers: int,
115
+ activation: nn.Module = nn.ReLU,
116
+ sigmoid_output: bool = False,
117
+ ) -> None:
118
+ super().__init__()
119
+ self.num_layers = num_layers
120
+ h = [hidden_dim] * (num_layers - 1)
121
+ self.layers = nn.ModuleList(
122
+ nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])
123
+ )
124
+ self.sigmoid_output = sigmoid_output
125
+ self.act = activation()
126
+
127
+ def forward(self, x):
128
+ for i, layer in enumerate(self.layers):
129
+ x = self.act(layer(x)) if i < self.num_layers - 1 else layer(x)
130
+ if self.sigmoid_output:
131
+ x = F.sigmoid(x)
132
+ return x
133
+
134
+
135
+ # From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa
136
+ # Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa
137
+ class LayerNorm2d(nn.Module):
138
+ def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
139
+ super().__init__()
140
+ self.weight = nn.Parameter(torch.ones(num_channels))
141
+ self.bias = nn.Parameter(torch.zeros(num_channels))
142
+ self.eps = eps
143
+
144
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
145
+ u = x.mean(1, keepdim=True)
146
+ s = (x - u).pow(2).mean(1, keepdim=True)
147
+ x = (x - u) / torch.sqrt(s + self.eps)
148
+ x = self.weight[:, None, None] * x + self.bias[:, None, None]
149
+ return x
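
A quick usage sketch for the two standalone helpers above, `select_closest_cond_frames` and `get_1d_sine_pe`; it assumes the repo root is on PYTHONPATH so that `third_parts` is importable.

import torch

from third_parts.sam2.modeling.sam2_utils import (
    get_1d_sine_pe,
    select_closest_cond_frames,
)

# Conditioning-frame outputs recorded at frames 0, 10 and 50 (placeholder values).
cond_outputs = {0: "out@0", 10: "out@10", 50: "out@50"}
selected, unselected = select_closest_cond_frames(
    frame_idx=12, cond_frame_outputs=cond_outputs, max_cond_frame_num=2
)
print(sorted(selected))    # [10, 50] -> closest frames before and after frame 12
print(sorted(unselected))  # [0]

# 256-dim sine temporal embedding for three normalized offsets,
# as used for the object-pointer positions in sam2_base.py.
pos = get_1d_sine_pe(torch.tensor([0.0, 0.5, 1.0]), dim=256)
print(pos.shape)  # torch.Size([3, 256])
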
third_parts/sam2/sam2_configs/__init__.py ADDED
@@ -0,0 +1,5 @@
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
third_parts/sam2/sam2_configs/sam2_hiera_b+.yaml ADDED
@@ -0,0 +1,113 @@
 
1
+ # @package _global_
2
+
3
+ # Model
4
+ model:
5
+ _target_: third_parts.sam2.modeling.sam2_base.SAM2Base
6
+ image_encoder:
7
+ _target_: third_parts.sam2.modeling.backbones.image_encoder.ImageEncoder
8
+ scalp: 1
9
+ trunk:
10
+ _target_: third_parts.sam2.modeling.backbones.hieradet.Hiera
11
+ embed_dim: 112
12
+ num_heads: 2
13
+ neck:
14
+ _target_: third_parts.sam2.modeling.backbones.image_encoder.FpnNeck
15
+ position_encoding:
16
+ _target_: third_parts.sam2.modeling.position_encoding.PositionEmbeddingSine
17
+ num_pos_feats: 256
18
+ normalize: true
19
+ scale: null
20
+ temperature: 10000
21
+ d_model: 256
22
+ backbone_channel_list: [896, 448, 224, 112]
23
+ fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
24
+ fpn_interp_model: nearest
25
+
26
+ memory_attention:
27
+ _target_: third_parts.sam2.modeling.memory_attention.MemoryAttention
28
+ d_model: 256
29
+ pos_enc_at_input: true
30
+ layer:
31
+ _target_: third_parts.sam2.modeling.memory_attention.MemoryAttentionLayer
32
+ activation: relu
33
+ dim_feedforward: 2048
34
+ dropout: 0.1
35
+ pos_enc_at_attn: false
36
+ self_attention:
37
+ _target_: third_parts.sam2.modeling.sam.transformer.RoPEAttention
38
+ rope_theta: 10000.0
39
+ feat_sizes: [32, 32]
40
+ embedding_dim: 256
41
+ num_heads: 1
42
+ downsample_rate: 1
43
+ dropout: 0.1
44
+ d_model: 256
45
+ pos_enc_at_cross_attn_keys: true
46
+ pos_enc_at_cross_attn_queries: false
47
+ cross_attention:
48
+ _target_: third_parts.sam2.modeling.sam.transformer.RoPEAttention
49
+ rope_theta: 10000.0
50
+ feat_sizes: [32, 32]
51
+ rope_k_repeat: True
52
+ embedding_dim: 256
53
+ num_heads: 1
54
+ downsample_rate: 1
55
+ dropout: 0.1
56
+ kv_in_dim: 64
57
+ num_layers: 4
58
+
59
+ memory_encoder:
60
+ _target_: third_parts.sam2.modeling.memory_encoder.MemoryEncoder
61
+ out_dim: 64
62
+ position_encoding:
63
+ _target_: third_parts.sam2.modeling.position_encoding.PositionEmbeddingSine
64
+ num_pos_feats: 64
65
+ normalize: true
66
+ scale: null
67
+ temperature: 10000
68
+ mask_downsampler:
69
+ _target_: third_parts.sam2.modeling.memory_encoder.MaskDownSampler
70
+ kernel_size: 3
71
+ stride: 2
72
+ padding: 1
73
+ fuser:
74
+ _target_: third_parts.sam2.modeling.memory_encoder.Fuser
75
+ layer:
76
+ _target_: third_parts.sam2.modeling.memory_encoder.CXBlock
77
+ dim: 256
78
+ kernel_size: 7
79
+ padding: 3
80
+ layer_scale_init_value: 1e-6
81
+ use_dwconv: True # depth-wise convs
82
+ num_layers: 2
83
+
84
+ num_maskmem: 7
85
+ image_size: 1024
86
+ # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
87
+ sigmoid_scale_for_mem_enc: 20.0
88
+ sigmoid_bias_for_mem_enc: -10.0
89
+ use_mask_input_as_output_without_sam: true
90
+ # Memory
91
+ directly_add_no_mem_embed: true
92
+ # use high-resolution feature map in the SAM mask decoder
93
+ use_high_res_features_in_sam: true
94
+ # output 3 masks on the first click on initial conditioning frames
95
+ multimask_output_in_sam: true
96
+ # SAM heads
97
+ iou_prediction_use_sigmoid: True
98
+ # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
99
+ use_obj_ptrs_in_encoder: true
100
+ add_tpos_enc_to_obj_ptrs: false
101
+ only_obj_ptrs_in_the_past_for_eval: true
102
+ # object occlusion prediction
103
+ pred_obj_scores: true
104
+ pred_obj_scores_mlp: true
105
+ fixed_no_obj_ptr: true
106
+ # multimask tracking settings
107
+ multimask_output_for_tracking: true
108
+ use_multimask_token_for_obj_ptr: true
109
+ multimask_min_pt_num: 0
110
+ multimask_max_pt_num: 1
111
+ use_mlp_for_obj_ptr_proj: true
112
+ # Compilation flag
113
+ compile_image_encoder: False
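
The file above is a standard Hydra `_target_` config, so it can also be instantiated by hand; a minimal sketch, assuming `hydra-core`/`omegaconf` are installed and the working directory is the repo root (in practice the repo's own build helpers are the intended entry point):

from hydra.utils import instantiate
from omegaconf import OmegaConf

cfg = OmegaConf.load("third_parts/sam2/sam2_configs/sam2_hiera_b+.yaml")
# Recursively builds the Hiera image encoder, memory attention and memory encoder
# from their `_target_` entries and passes the remaining keys
# (num_maskmem, image_size, ...) to SAM2Base.
model = instantiate(cfg.model)
print(type(model).__name__)  # SAM2Base
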
third_parts/sam2/sam2_configs/sam2_hiera_l.yaml ADDED
@@ -0,0 +1,117 @@
 
1
+ # @package _global_
2
+
3
+ # Model
4
+ model:
5
+ _target_: third_parts.sam2.modeling.sam2_base.SAM2Base
6
+ image_encoder:
7
+ _target_: third_parts.sam2.modeling.backbones.image_encoder.ImageEncoder
8
+ scalp: 1
9
+ trunk:
10
+ _target_: third_parts.sam2.modeling.backbones.hieradet.Hiera
11
+ embed_dim: 144
12
+ num_heads: 2
13
+ stages: [2, 6, 36, 4]
14
+ global_att_blocks: [23, 33, 43]
15
+ window_pos_embed_bkg_spatial_size: [7, 7]
16
+ window_spec: [8, 4, 16, 8]
17
+ neck:
18
+ _target_: third_parts.sam2.modeling.backbones.image_encoder.FpnNeck
19
+ position_encoding:
20
+ _target_: third_parts.sam2.modeling.position_encoding.PositionEmbeddingSine
21
+ num_pos_feats: 256
22
+ normalize: true
23
+ scale: null
24
+ temperature: 10000
25
+ d_model: 256
26
+ backbone_channel_list: [1152, 576, 288, 144]
27
+ fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
28
+ fpn_interp_model: nearest
29
+
30
+ memory_attention:
31
+ _target_: third_parts.sam2.modeling.memory_attention.MemoryAttention
32
+ d_model: 256
33
+ pos_enc_at_input: true
34
+ layer:
35
+ _target_: third_parts.sam2.modeling.memory_attention.MemoryAttentionLayer
36
+ activation: relu
37
+ dim_feedforward: 2048
38
+ dropout: 0.1
39
+ pos_enc_at_attn: false
40
+ self_attention:
41
+ _target_: third_parts.sam2.modeling.sam.transformer.RoPEAttention
42
+ rope_theta: 10000.0
43
+ feat_sizes: [32, 32]
44
+ embedding_dim: 256
45
+ num_heads: 1
46
+ downsample_rate: 1
47
+ dropout: 0.1
48
+ d_model: 256
49
+ pos_enc_at_cross_attn_keys: true
50
+ pos_enc_at_cross_attn_queries: false
51
+ cross_attention:
52
+ _target_: third_parts.sam2.modeling.sam.transformer.RoPEAttention
53
+ rope_theta: 10000.0
54
+ feat_sizes: [32, 32]
55
+ rope_k_repeat: True
56
+ embedding_dim: 256
57
+ num_heads: 1
58
+ downsample_rate: 1
59
+ dropout: 0.1
60
+ kv_in_dim: 64
61
+ num_layers: 4
62
+
63
+ memory_encoder:
64
+ _target_: third_parts.sam2.modeling.memory_encoder.MemoryEncoder
65
+ out_dim: 64
66
+ position_encoding:
67
+ _target_: third_parts.sam2.modeling.position_encoding.PositionEmbeddingSine
68
+ num_pos_feats: 64
69
+ normalize: true
70
+ scale: null
71
+ temperature: 10000
72
+ mask_downsampler:
73
+ _target_: third_parts.sam2.modeling.memory_encoder.MaskDownSampler
74
+ kernel_size: 3
75
+ stride: 2
76
+ padding: 1
77
+ fuser:
78
+ _target_: third_parts.sam2.modeling.memory_encoder.Fuser
79
+ layer:
80
+ _target_: third_parts.sam2.modeling.memory_encoder.CXBlock
81
+ dim: 256
82
+ kernel_size: 7
83
+ padding: 3
84
+ layer_scale_init_value: 1e-6
85
+ use_dwconv: True # depth-wise convs
86
+ num_layers: 2
87
+
88
+ num_maskmem: 7
89
+ image_size: 1024
90
+ # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
91
+ sigmoid_scale_for_mem_enc: 20.0
92
+ sigmoid_bias_for_mem_enc: -10.0
93
+ use_mask_input_as_output_without_sam: true
94
+ # Memory
95
+ directly_add_no_mem_embed: true
96
+ # use high-resolution feature map in the SAM mask decoder
97
+ use_high_res_features_in_sam: true
98
+ # output 3 masks on the first click on initial conditioning frames
99
+ multimask_output_in_sam: true
100
+ # SAM heads
101
+ iou_prediction_use_sigmoid: True
102
+ # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
103
+ use_obj_ptrs_in_encoder: true
104
+ add_tpos_enc_to_obj_ptrs: false
105
+ only_obj_ptrs_in_the_past_for_eval: true
106
+ # object occlusion prediction
107
+ pred_obj_scores: true
108
+ pred_obj_scores_mlp: true
109
+ fixed_no_obj_ptr: true
110
+ # multimask tracking settings
111
+ multimask_output_for_tracking: true
112
+ use_multimask_token_for_obj_ptr: true
113
+ multimask_min_pt_num: 0
114
+ multimask_max_pt_num: 1
115
+ use_mlp_for_obj_ptr_proj: true
116
+ # Compilation flag
117
+ compile_image_encoder: False
third_parts/sam2/sam2_configs/sam2_hiera_s.yaml ADDED
@@ -0,0 +1,116 @@
 
1
+ # @package _global_
2
+
3
+ # Model
4
+ model:
5
+ _target_: third_parts.sam2.modeling.sam2_base.SAM2Base
6
+ image_encoder:
7
+ _target_: third_parts.sam2.modeling.backbones.image_encoder.ImageEncoder
8
+ scalp: 1
9
+ trunk:
10
+ _target_: third_parts.sam2.modeling.backbones.hieradet.Hiera
11
+ embed_dim: 96
12
+ num_heads: 1
13
+ stages: [1, 2, 11, 2]
14
+ global_att_blocks: [7, 10, 13]
15
+ window_pos_embed_bkg_spatial_size: [7, 7]
16
+ neck:
17
+ _target_: third_parts.sam2.modeling.backbones.image_encoder.FpnNeck
18
+ position_encoding:
19
+ _target_: third_parts.sam2.modeling.position_encoding.PositionEmbeddingSine
20
+ num_pos_feats: 256
21
+ normalize: true
22
+ scale: null
23
+ temperature: 10000
24
+ d_model: 256
25
+ backbone_channel_list: [768, 384, 192, 96]
26
+ fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
27
+ fpn_interp_model: nearest
28
+
29
+ memory_attention:
30
+ _target_: third_parts.sam2.modeling.memory_attention.MemoryAttention
31
+ d_model: 256
32
+ pos_enc_at_input: true
33
+ layer:
34
+ _target_: third_parts.sam2.modeling.memory_attention.MemoryAttentionLayer
35
+ activation: relu
36
+ dim_feedforward: 2048
37
+ dropout: 0.1
38
+ pos_enc_at_attn: false
39
+ self_attention:
40
+ _target_: third_parts.sam2.modeling.sam.transformer.RoPEAttention
41
+ rope_theta: 10000.0
42
+ feat_sizes: [32, 32]
43
+ embedding_dim: 256
44
+ num_heads: 1
45
+ downsample_rate: 1
46
+ dropout: 0.1
47
+ d_model: 256
48
+ pos_enc_at_cross_attn_keys: true
49
+ pos_enc_at_cross_attn_queries: false
50
+ cross_attention:
51
+ _target_: third_parts.sam2.modeling.sam.transformer.RoPEAttention
52
+ rope_theta: 10000.0
53
+ feat_sizes: [32, 32]
54
+ rope_k_repeat: True
55
+ embedding_dim: 256
56
+ num_heads: 1
57
+ downsample_rate: 1
58
+ dropout: 0.1
59
+ kv_in_dim: 64
60
+ num_layers: 4
61
+
62
+ memory_encoder:
63
+ _target_: third_parts.sam2.modeling.memory_encoder.MemoryEncoder
64
+ out_dim: 64
65
+ position_encoding:
66
+ _target_: third_parts.sam2.modeling.position_encoding.PositionEmbeddingSine
67
+ num_pos_feats: 64
68
+ normalize: true
69
+ scale: null
70
+ temperature: 10000
71
+ mask_downsampler:
72
+ _target_: third_parts.sam2.modeling.memory_encoder.MaskDownSampler
73
+ kernel_size: 3
74
+ stride: 2
75
+ padding: 1
76
+ fuser:
77
+ _target_: third_parts.sam2.modeling.memory_encoder.Fuser
78
+ layer:
79
+ _target_: third_parts.sam2.modeling.memory_encoder.CXBlock
80
+ dim: 256
81
+ kernel_size: 7
82
+ padding: 3
83
+ layer_scale_init_value: 1e-6
84
+ use_dwconv: True # depth-wise convs
85
+ num_layers: 2
86
+
87
+ num_maskmem: 7
88
+ image_size: 1024
89
+ # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
90
+ sigmoid_scale_for_mem_enc: 20.0
91
+ sigmoid_bias_for_mem_enc: -10.0
92
+ use_mask_input_as_output_without_sam: true
93
+ # Memory
94
+ directly_add_no_mem_embed: true
95
+ # use high-resolution feature map in the SAM mask decoder
96
+ use_high_res_features_in_sam: true
97
+ # output 3 masks on the first click on initial conditioning frames
98
+ multimask_output_in_sam: true
99
+ # SAM heads
100
+ iou_prediction_use_sigmoid: True
101
+ # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
102
+ use_obj_ptrs_in_encoder: true
103
+ add_tpos_enc_to_obj_ptrs: false
104
+ only_obj_ptrs_in_the_past_for_eval: true
105
+ # object occlusion prediction
106
+ pred_obj_scores: true
107
+ pred_obj_scores_mlp: true
108
+ fixed_no_obj_ptr: true
109
+ # multimask tracking settings
110
+ multimask_output_for_tracking: true
111
+ use_multimask_token_for_obj_ptr: true
112
+ multimask_min_pt_num: 0
113
+ multimask_max_pt_num: 1
114
+ use_mlp_for_obj_ptr_proj: true
115
+ # Compilation flag
116
+ compile_image_encoder: False
third_parts/sam2/sam2_configs/sam2_hiera_t.yaml ADDED
@@ -0,0 +1,118 @@
 
1
+ # @package _global_
2
+
3
+ # Model
4
+ model:
5
+ _target_: third_parts.sam2.modeling.sam2_base.SAM2Base
6
+ image_encoder:
7
+ _target_: third_parts.sam2.modeling.backbones.image_encoder.ImageEncoder
8
+ scalp: 1
9
+ trunk:
10
+ _target_: third_parts.sam2.modeling.backbones.hieradet.Hiera
11
+ embed_dim: 96
12
+ num_heads: 1
13
+ stages: [1, 2, 7, 2]
14
+ global_att_blocks: [5, 7, 9]
15
+ window_pos_embed_bkg_spatial_size: [7, 7]
16
+ neck:
17
+ _target_: third_parts.sam2.modeling.backbones.image_encoder.FpnNeck
18
+ position_encoding:
19
+ _target_: third_parts.sam2.modeling.position_encoding.PositionEmbeddingSine
20
+ num_pos_feats: 256
21
+ normalize: true
22
+ scale: null
23
+ temperature: 10000
24
+ d_model: 256
25
+ backbone_channel_list: [768, 384, 192, 96]
26
+ fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
27
+ fpn_interp_model: nearest
28
+
29
+ memory_attention:
30
+ _target_: third_parts.sam2.modeling.memory_attention.MemoryAttention
31
+ d_model: 256
32
+ pos_enc_at_input: true
33
+ layer:
34
+ _target_: third_parts.sam2.modeling.memory_attention.MemoryAttentionLayer
35
+ activation: relu
36
+ dim_feedforward: 2048
37
+ dropout: 0.1
38
+ pos_enc_at_attn: false
39
+ self_attention:
40
+ _target_: third_parts.sam2.modeling.sam.transformer.RoPEAttention
41
+ rope_theta: 10000.0
42
+ feat_sizes: [32, 32]
43
+ embedding_dim: 256
44
+ num_heads: 1
45
+ downsample_rate: 1
46
+ dropout: 0.1
47
+ d_model: 256
48
+ pos_enc_at_cross_attn_keys: true
49
+ pos_enc_at_cross_attn_queries: false
50
+ cross_attention:
51
+ _target_: third_parts.sam2.modeling.sam.transformer.RoPEAttention
52
+ rope_theta: 10000.0
53
+ feat_sizes: [32, 32]
54
+ rope_k_repeat: True
55
+ embedding_dim: 256
56
+ num_heads: 1
57
+ downsample_rate: 1
58
+ dropout: 0.1
59
+ kv_in_dim: 64
60
+ num_layers: 4
61
+
62
+ memory_encoder:
63
+ _target_: third_parts.sam2.modeling.memory_encoder.MemoryEncoder
64
+ out_dim: 64
65
+ position_encoding:
66
+ _target_: third_parts.sam2.modeling.position_encoding.PositionEmbeddingSine
67
+ num_pos_feats: 64
68
+ normalize: true
69
+ scale: null
70
+ temperature: 10000
71
+ mask_downsampler:
72
+ _target_: third_parts.sam2.modeling.memory_encoder.MaskDownSampler
73
+ kernel_size: 3
74
+ stride: 2
75
+ padding: 1
76
+ fuser:
77
+ _target_: third_parts.sam2.modeling.memory_encoder.Fuser
78
+ layer:
79
+ _target_: third_parts.sam2.modeling.memory_encoder.CXBlock
80
+ dim: 256
81
+ kernel_size: 7
82
+ padding: 3
83
+ layer_scale_init_value: 1e-6
84
+ use_dwconv: True # depth-wise convs
85
+ num_layers: 2
86
+
87
+ num_maskmem: 7
88
+ image_size: 1024
89
+ # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
90
+ # SAM decoder
91
+ sigmoid_scale_for_mem_enc: 20.0
92
+ sigmoid_bias_for_mem_enc: -10.0
93
+ use_mask_input_as_output_without_sam: true
94
+ # Memory
95
+ directly_add_no_mem_embed: true
96
+ # use high-resolution feature map in the SAM mask decoder
97
+ use_high_res_features_in_sam: true
98
+ # output 3 masks on the first click on initial conditioning frames
99
+ multimask_output_in_sam: true
100
+ # SAM heads
101
+ iou_prediction_use_sigmoid: True
102
+ # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
103
+ use_obj_ptrs_in_encoder: true
104
+ add_tpos_enc_to_obj_ptrs: false
105
+ only_obj_ptrs_in_the_past_for_eval: true
106
+ # object occlusion prediction
107
+ pred_obj_scores: true
108
+ pred_obj_scores_mlp: true
109
+ fixed_no_obj_ptr: true
110
+ # multimask tracking settings
111
+ multimask_output_for_tracking: true
112
+ use_multimask_token_for_obj_ptr: true
113
+ multimask_min_pt_num: 0
114
+ multimask_max_pt_num: 1
115
+ use_mlp_for_obj_ptr_proj: true
116
+ # Compilation flag
117
+ # HieraT does not currently support compilation, should always be set to False
118
+ compile_image_encoder: False
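
One hyper-parameter pairing in these configs is worth spelling out: `sigmoid_scale_for_mem_enc: 20.0` together with `sigmoid_bias_for_mem_enc: -10.0` remaps sigmoid mask probabilities in (0, 1) to roughly (-10, 10) before they reach the memory encoder, matching the hard -10/+10 logits that `_use_mask_as_output` in sam2_base.py assigns when a ground-truth mask is passed through directly. A tiny numeric check (only `torch` required):

import torch

logits = torch.tensor([-8.0, 0.0, 8.0])              # raw mask logits
mask_for_mem = torch.sigmoid(logits) * 20.0 - 10.0   # scale then bias, as in _encode_new_memory
print(mask_for_mem)  # tensor([-9.9933,  0.0000,  9.9933])
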
third_parts/sam2/sam2_image_predictor.py ADDED
@@ -0,0 +1,446 @@
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import logging
8
+
9
+ from typing import List, Optional, Tuple, Union
10
+
11
+ import numpy as np
12
+ import torch
13
+ from PIL.Image import Image
14
+
15
+ from third_parts.sam2.modeling.sam2_base import SAM2Base
16
+
17
+ from third_parts.sam2.utils.transforms import SAM2Transforms
18
+
19
+
20
+ class SAM2ImagePredictor:
21
+ def __init__(
22
+ self,
23
+ sam_model: SAM2Base,
24
+ mask_threshold=0.0,
25
+ max_hole_area=0.0,
26
+ max_sprinkle_area=0.0,
27
+ ) -> None:
28
+ """
29
+ Uses SAM-2 to calculate the image embedding for an image, and then
30
+ allows repeated, efficient mask prediction given prompts.
31
+
32
+ Arguments:
33
+ sam_model (SAM2Base): The model to use for mask prediction.
34
+ mask_threshold (float): The threshold to use when converting mask logits
35
+ to binary masks. Masks are thresholded at 0 by default.
36
+ max_hole_area (float): If max_hole_area > 0, we fill small holes of up to
37
+ this maximum area in low_res_masks.
38
+ """
39
+ super().__init__()
40
+ self.model = sam_model
41
+ self._transforms = SAM2Transforms(
42
+ resolution=self.model.image_size,
43
+ mask_threshold=mask_threshold,
44
+ max_hole_area=max_hole_area,
45
+ max_sprinkle_area=max_sprinkle_area,
46
+ )
47
+
48
+ # Predictor state
49
+ self._is_image_set = False
50
+ self._features = None
51
+ self._orig_hw = None
52
+ # Whether the predictor is set for single image or a batch of images
53
+ self._is_batch = False
54
+
55
+ # Predictor config
56
+ self.mask_threshold = mask_threshold
57
+
58
+ # Spatial dim for backbone feature maps
59
+ self._bb_feat_sizes = [
60
+ (256, 256),
61
+ (128, 128),
62
+ (64, 64),
63
+ ]
64
+
65
+ @torch.no_grad()
66
+ def set_image(
67
+ self,
68
+ image: Union[np.ndarray, Image],
69
+ ) -> None:
70
+ """
71
+ Calculates the image embeddings for the provided image, allowing
72
+ masks to be predicted with the 'predict' method.
73
+
74
+ Arguments:
75
+ image (np.ndarray or PIL Image): The input image to embed, in RGB format. The image should be in HWC format if an np.ndarray, or a PIL Image,
76
+ with pixel values in [0, 255].
78
+ """
79
+ self.reset_predictor()
80
+ # Transform the image to the form expected by the model
81
+ if isinstance(image, np.ndarray):
82
+ logging.info("For numpy array image, we assume (HxWxC) format")
83
+ self._orig_hw = [image.shape[:2]]
84
+ elif isinstance(image, Image):
85
+ w, h = image.size
86
+ self._orig_hw = [(h, w)]
87
+ else:
88
+ raise NotImplementedError("Image format not supported")
89
+
90
+ input_image = self._transforms(image)
91
+ input_image = input_image[None, ...].to(self.device)
92
+
93
+ assert (
94
+ len(input_image.shape) == 4 and input_image.shape[1] == 3
95
+ ), f"input_image must be of size 1x3xHxW, got {input_image.shape}"
96
+ logging.info("Computing image embeddings for the provided image...")
97
+ backbone_out = self.model.forward_image(input_image)
98
+ _, vision_feats, _, _ = self.model._prepare_backbone_features(backbone_out)
99
+ # Add no_mem_embed, which is added to the lowest-res feat. map during training on videos
100
+ if self.model.directly_add_no_mem_embed:
101
+ vision_feats[-1] = vision_feats[-1] + self.model.no_mem_embed
102
+
103
+ feats = [
104
+ feat.permute(1, 2, 0).view(1, -1, *feat_size)
105
+ for feat, feat_size in zip(vision_feats[::-1], self._bb_feat_sizes[::-1])
106
+ ][::-1]
107
+ self._features = {"image_embed": feats[-1], "high_res_feats": feats[:-1]}
108
+ self._is_image_set = True
109
+ logging.info("Image embeddings computed.")
110
+
111
+ @torch.no_grad()
112
+ def set_image_batch(
113
+ self,
114
+ image_list: List[Union[np.ndarray]],
115
+ ) -> None:
116
+ """
117
+ Calculates the image embeddings for the provided image batch, allowing
118
+ masks to be predicted with the 'predict_batch' method.
119
+
120
+ Arguments:
121
+ image_list (List[np.ndarray]): The input images to embed in RGB format. The image should be in HWC format if np.ndarray
122
+ with pixel values in [0, 255].
123
+ """
124
+ self.reset_predictor()
125
+ assert isinstance(image_list, list)
126
+ self._orig_hw = []
127
+ for image in image_list:
128
+ assert isinstance(
129
+ image, np.ndarray
130
+ ), "Images are expected to be an np.ndarray in RGB format, and of shape HWC"
131
+ self._orig_hw.append(image.shape[:2])
132
+ # Transform the image to the form expected by the model
133
+ img_batch = self._transforms.forward_batch(image_list)
134
+ img_batch = img_batch.to(self.device)
135
+ batch_size = img_batch.shape[0]
136
+ assert (
137
+ len(img_batch.shape) == 4 and img_batch.shape[1] == 3
138
+ ), f"img_batch must be of size Bx3xHxW, got {img_batch.shape}"
139
+ logging.info("Computing image embeddings for the provided images...")
140
+ backbone_out = self.model.forward_image(img_batch)
141
+ _, vision_feats, _, _ = self.model._prepare_backbone_features(backbone_out)
142
+ # Add no_mem_embed, which is added to the lowest rest feat. map during training on videos
143
+ if self.model.directly_add_no_mem_embed:
144
+ vision_feats[-1] = vision_feats[-1] + self.model.no_mem_embed
145
+
146
+ feats = [
147
+ feat.permute(1, 2, 0).view(batch_size, -1, *feat_size)
148
+ for feat, feat_size in zip(vision_feats[::-1], self._bb_feat_sizes[::-1])
149
+ ][::-1]
150
+ self._features = {"image_embed": feats[-1], "high_res_feats": feats[:-1]}
151
+ self._is_image_set = True
152
+ self._is_batch = True
153
+ logging.info("Image embeddings computed.")
154
+
155
+ def predict_batch(
156
+ self,
157
+ point_coords_batch: List[np.ndarray] = None,
158
+ point_labels_batch: List[np.ndarray] = None,
159
+ box_batch: List[np.ndarray] = None,
160
+ mask_input_batch: List[np.ndarray] = None,
161
+ multimask_output: bool = True,
162
+ return_logits: bool = False,
163
+ normalize_coords=True,
164
+ ) -> Tuple[List[np.ndarray], List[np.ndarray], List[np.ndarray]]:
165
+ """This function is very similar to predict(...), however it is used for batched mode, when the model is expected to generate predictions on multiple images.
166
+ It returns a tuple of lists of masks, ious, and low_res_masks_logits.
167
+ """
168
+ assert self._is_batch, "This function should only be used when in batched mode"
169
+ if not self._is_image_set:
170
+ raise RuntimeError(
171
+ "An image must be set with .set_image_batch(...) before mask prediction."
172
+ )
173
+ num_images = len(self._features["image_embed"])
174
+ all_masks = []
175
+ all_ious = []
176
+ all_low_res_masks = []
177
+ for img_idx in range(num_images):
178
+ # Transform input prompts
179
+ point_coords = (
180
+ point_coords_batch[img_idx] if point_coords_batch is not None else None
181
+ )
182
+ point_labels = (
183
+ point_labels_batch[img_idx] if point_labels_batch is not None else None
184
+ )
185
+ box = box_batch[img_idx] if box_batch is not None else None
186
+ mask_input = (
187
+ mask_input_batch[img_idx] if mask_input_batch is not None else None
188
+ )
189
+ mask_input, unnorm_coords, labels, unnorm_box = self._prep_prompts(
190
+ point_coords,
191
+ point_labels,
192
+ box,
193
+ mask_input,
194
+ normalize_coords,
195
+ img_idx=img_idx,
196
+ )
197
+ masks, iou_predictions, low_res_masks = self._predict(
198
+ unnorm_coords,
199
+ labels,
200
+ unnorm_box,
201
+ mask_input,
202
+ multimask_output,
203
+ return_logits=return_logits,
204
+ img_idx=img_idx,
205
+ )
206
+ masks_np = masks.squeeze(0).float().detach().cpu().numpy()
207
+ iou_predictions_np = (
208
+ iou_predictions.squeeze(0).float().detach().cpu().numpy()
209
+ )
210
+ low_res_masks_np = low_res_masks.squeeze(0).float().detach().cpu().numpy()
211
+ all_masks.append(masks_np)
212
+ all_ious.append(iou_predictions_np)
213
+ all_low_res_masks.append(low_res_masks_np)
214
+
215
+ return all_masks, all_ious, all_low_res_masks
216
+
+     def predict(
+         self,
+         point_coords: Optional[np.ndarray] = None,
+         point_labels: Optional[np.ndarray] = None,
+         box: Optional[np.ndarray] = None,
+         mask_input: Optional[np.ndarray] = None,
+         multimask_output: bool = True,
+         return_logits: bool = False,
+         normalize_coords=True,
+     ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
+         """
+         Predict masks for the given input prompts, using the currently set image.
+
+         Arguments:
+           point_coords (np.ndarray or None): A Nx2 array of point prompts to the
+             model. Each point is in (X,Y) in pixels.
+           point_labels (np.ndarray or None): A length N array of labels for the
+             point prompts. 1 indicates a foreground point and 0 indicates a
+             background point.
+           box (np.ndarray or None): A length 4 array giving a box prompt to the
+             model, in XYXY format.
+           mask_input (np.ndarray): A low resolution mask input to the model, typically
+             coming from a previous prediction iteration. Has form 1xHxW, where
+             for SAM, H=W=256.
+           multimask_output (bool): If true, the model will return three masks.
+             For ambiguous input prompts (such as a single click), this will often
+             produce better masks than a single prediction. If only a single
+             mask is needed, the model's predicted quality score can be used
+             to select the best mask. For non-ambiguous prompts, such as multiple
+             input prompts, multimask_output=False can give better results.
+           return_logits (bool): If true, returns un-thresholded mask logits
+             instead of a binary mask.
+           normalize_coords (bool): If true, the point coordinates will be normalized
+             to the range [0,1]; point_coords is expected to be given with respect to
+             the image dimensions.
+
+         Returns:
+           (np.ndarray): The output masks in CxHxW format, where C is the
+             number of masks, and (H, W) is the original image size.
+           (np.ndarray): An array of length C containing the model's
+             predictions for the quality of each mask.
+           (np.ndarray): An array of shape CxHxW, where C is the number
+             of masks and H=W=256. These low resolution logits can be passed to
+             a subsequent iteration as mask input.
+         """
+         if not self._is_image_set:
+             raise RuntimeError(
+                 "An image must be set with .set_image(...) before mask prediction."
+             )
+
+         # Transform input prompts
+
+         mask_input, unnorm_coords, labels, unnorm_box = self._prep_prompts(
+             point_coords, point_labels, box, mask_input, normalize_coords
+         )
+
+         masks, iou_predictions, low_res_masks = self._predict(
+             unnorm_coords,
+             labels,
+             unnorm_box,
+             mask_input,
+             multimask_output,
+             return_logits=return_logits,
+         )
+
+         masks_np = masks.squeeze(0).float().detach().cpu().numpy()
+         iou_predictions_np = iou_predictions.squeeze(0).float().detach().cpu().numpy()
+         low_res_masks_np = low_res_masks.squeeze(0).float().detach().cpu().numpy()
+         return masks_np, iou_predictions_np, low_res_masks_np
+
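# Illustrative single-image sketch for predict(...) above, assuming `predictor`
# is an already-constructed instance of this class and `image` is an HWC RGB array.
import numpy as np

def click_to_mask(predictor, image, x, y):
    predictor.set_image(image)                    # runs the image encoder once
    masks, scores, low_res_logits = predictor.predict(
        point_coords=np.array([[x, y]], dtype=np.float32),  # one (X, Y) pixel click
        point_labels=np.array([1], dtype=np.int32),          # 1 = foreground
        multimask_output=True,                                # three candidate masks
    )
    best = int(np.argmax(scores))                 # choose by predicted quality
    # low_res_logits[best][None] is a 1xHxW array that can be passed back as
    # mask_input for a refinement call.
    return masks[best], scores[best], low_res_logits[best][None]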
+     def _prep_prompts(
+         self, point_coords, point_labels, box, mask_logits, normalize_coords, img_idx=-1
+     ):
+
+         unnorm_coords, labels, unnorm_box, mask_input = None, None, None, None
+         if point_coords is not None:
+             assert (
+                 point_labels is not None
+             ), "point_labels must be supplied if point_coords is supplied."
+             point_coords = torch.as_tensor(
+                 point_coords, dtype=torch.float, device=self.device
+             )
+             unnorm_coords = self._transforms.transform_coords(
+                 point_coords, normalize=normalize_coords, orig_hw=self._orig_hw[img_idx]
+             )
+             labels = torch.as_tensor(point_labels, dtype=torch.int, device=self.device)
+             if len(unnorm_coords.shape) == 2:
+                 unnorm_coords, labels = unnorm_coords[None, ...], labels[None, ...]
+         if box is not None:
+             box = torch.as_tensor(box, dtype=torch.float, device=self.device)
+             unnorm_box = self._transforms.transform_boxes(
+                 box, normalize=normalize_coords, orig_hw=self._orig_hw[img_idx]
+             )  # Bx2x2
+         if mask_logits is not None:
+             mask_input = torch.as_tensor(
+                 mask_logits, dtype=torch.float, device=self.device
+             )
+             if len(mask_input.shape) == 3:
+                 mask_input = mask_input[None, :, :, :]
+         return mask_input, unnorm_coords, labels, unnorm_box
+
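# For reference, the prompt layouts implied by _prep_prompts above (the values
# below are illustrative only):
import numpy as np

point_coords = np.array([[120.0, 340.0], [95.0, 410.0]])  # Nx2, (X, Y) in pixels
point_labels = np.array([1, 0])                           # length N, 1=foreground, 0=background
box = np.array([80.0, 300.0, 260.0, 480.0])               # length-4 XYXY box
mask_logits = np.zeros((1, 256, 256), dtype=np.float32)   # 1xHxW low-resolution logits
# _prep_prompts moves these onto the model device, rescales coordinates to the
# model's input frame, and inserts a leading batch dimension where it is missing.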
+     @torch.no_grad()
+     def _predict(
+         self,
+         point_coords: Optional[torch.Tensor],
+         point_labels: Optional[torch.Tensor],
+         boxes: Optional[torch.Tensor] = None,
+         mask_input: Optional[torch.Tensor] = None,
+         multimask_output: bool = True,
+         return_logits: bool = False,
+         img_idx: int = -1,
+     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+         """
+         Predict masks for the given input prompts, using the currently set image.
+         Input prompts are batched torch tensors and are expected to already be
+         transformed to the input frame using SAM2Transforms.
+
+         Arguments:
+           point_coords (torch.Tensor or None): A BxNx2 array of point prompts to the
+             model. Each point is in (X,Y) in pixels.
+           point_labels (torch.Tensor or None): A BxN array of labels for the
+             point prompts. 1 indicates a foreground point and 0 indicates a
+             background point.
+           boxes (torch.Tensor or None): A Bx4 array giving a box prompt to the
+             model, in XYXY format.
+           mask_input (torch.Tensor): A low resolution mask input to the model, typically
+             coming from a previous prediction iteration. Has form Bx1xHxW, where
+             for SAM, H=W=256. Masks returned by a previous iteration of the
+             predict method do not need further transformation.
+           multimask_output (bool): If true, the model will return three masks.
+             For ambiguous input prompts (such as a single click), this will often
+             produce better masks than a single prediction. If only a single
+             mask is needed, the model's predicted quality score can be used
+             to select the best mask. For non-ambiguous prompts, such as multiple
+             input prompts, multimask_output=False can give better results.
+           return_logits (bool): If true, returns un-thresholded mask logits
+             instead of a binary mask.
+
+         Returns:
+           (torch.Tensor): The output masks in BxCxHxW format, where C is the
+             number of masks, and (H, W) is the original image size.
+           (torch.Tensor): An array of shape BxC containing the model's
+             predictions for the quality of each mask.
+           (torch.Tensor): An array of shape BxCxHxW, where C is the number
+             of masks and H=W=256. These low res logits can be passed to
+             a subsequent iteration as mask input.
+         """
+         if not self._is_image_set:
+             raise RuntimeError(
+                 "An image must be set with .set_image(...) before mask prediction."
+             )
+
+         if point_coords is not None:
+             concat_points = (point_coords, point_labels)
+         else:
+             concat_points = None
+
+         # Embed prompts
+         if boxes is not None:
+             box_coords = boxes.reshape(-1, 2, 2)
+             box_labels = torch.tensor([[2, 3]], dtype=torch.int, device=boxes.device)
+             box_labels = box_labels.repeat(boxes.size(0), 1)
+             # we merge "boxes" and "points" into a single "concat_points" input (where
+             # boxes are added at the beginning) to sam_prompt_encoder
+             if concat_points is not None:
+                 concat_coords = torch.cat([box_coords, concat_points[0]], dim=1)
+                 concat_labels = torch.cat([box_labels, concat_points[1]], dim=1)
+                 concat_points = (concat_coords, concat_labels)
+             else:
+                 concat_points = (box_coords, box_labels)
+
+         sparse_embeddings, dense_embeddings = self.model.sam_prompt_encoder(
+             points=concat_points,
+             boxes=None,
+             masks=mask_input,
+         )
+
+         # Predict masks
+         batched_mode = (
+             concat_points is not None and concat_points[0].shape[0] > 1
+         )  # multi object prediction
+         high_res_features = [
+             feat_level[img_idx].unsqueeze(0)
+             for feat_level in self._features["high_res_feats"]
+         ]
+         low_res_masks, iou_predictions, _, _ = self.model.sam_mask_decoder(
+             image_embeddings=self._features["image_embed"][img_idx].unsqueeze(0),
+             image_pe=self.model.sam_prompt_encoder.get_dense_pe(),
+             sparse_prompt_embeddings=sparse_embeddings,
+             dense_prompt_embeddings=dense_embeddings,
+             multimask_output=multimask_output,
+             repeat_image=batched_mode,
+             high_res_features=high_res_features,
+         )
+
+         # Upscale the masks to the original image resolution
+         masks = self._transforms.postprocess_masks(
+             low_res_masks, self._orig_hw[img_idx]
+         )
+         low_res_masks = torch.clamp(low_res_masks, -32.0, 32.0)
+         if not return_logits:
+             masks = masks > self.mask_threshold
+
+         return masks, iou_predictions, low_res_masks
+
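# A small sketch of the box handling in _predict above: each Bx4 XYXY box is
# reshaped into two corner points and tagged with the reserved labels 2 and 3
# (SAM's box-corner convention) before being concatenated with any click prompts.
import torch

boxes = torch.tensor([[80.0, 300.0, 260.0, 480.0]])              # Bx4, XYXY
box_coords = boxes.reshape(-1, 2, 2)                             # Bx2x2 corner points
box_labels = torch.tensor([[2, 3]], dtype=torch.int).repeat(boxes.size(0), 1)
print(box_coords.shape, box_labels.tolist())                     # torch.Size([1, 2, 2]) [[2, 3]]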
+     def get_image_embedding(self) -> torch.Tensor:
+         """
+         Returns the image embeddings for the currently set image, with
+         shape 1xCxHxW, where C is the embedding dimension and (H, W) are
+         the embedding spatial dimensions of SAM (typically C=256, H=W=64).
+         """
+         if not self._is_image_set:
+             raise RuntimeError(
+                 "An image must be set with .set_image(...) to generate an embedding."
+             )
+         assert (
+             self._features is not None
+         ), "Features must exist if an image has been set."
+         return self._features["image_embed"]
+
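# Illustrative sketch: the cached embedding can be pulled out after set_image(...),
# e.g. to verify the expected 1xCxHxW layout or to store features offline.
# `predictor` and `image` are assumed, as in the earlier sketches.
import torch

def dump_embedding(predictor, image, path="embedding.pt"):
    predictor.set_image(image)
    embedding = predictor.get_image_embedding()   # 1xCxHxW (typically C=256, H=W=64)
    torch.save(embedding.cpu(), path)
    return tuple(embedding.shape)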
+     @property
+     def device(self) -> torch.device:
+         return self.model.device
+
+     def reset_predictor(self) -> None:
+         """
+         Resets the image embeddings and other state variables.
+         """
+         self._is_image_set = False
+         self._features = None
+         self._orig_hw = None
+         self._is_batch = False
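# Illustrative sketch: reset_predictor() clears the cached features, so one
# predictor instance can be reused over a stream of images without carrying
# state between them. `predictor`, `images`, and `boxes_xyxy` are assumptions.
import numpy as np

def segment_stream(predictor, images, boxes_xyxy):
    outputs = []
    for image, box in zip(images, boxes_xyxy):
        predictor.set_image(image)
        masks, scores, _ = predictor.predict(
            box=np.asarray(box, dtype=np.float32),
            multimask_output=False,
        )
        outputs.append((masks, scores))
        predictor.reset_predictor()   # drop embeddings before the next image
    return outputs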