buyun committed on
Commit
3a7c6d9 · verified · 1 Parent(s): 4ad74b5

Add files using upload-large-folder tool

Files changed (2)
  1. processing_step3.py +465 -0
  2. processor_config.json +2 -2
processing_step3.py ADDED
@@ -0,0 +1,465 @@
+from transformers.processing_utils import ImagesKwargs, MultiModalData, ProcessingKwargs, ProcessorMixin, Unpack, VideosKwargs
+import math
+from typing import Iterable, Optional, Tuple, List, TypedDict, Literal, Union, overload
+
+from PIL import Image
+import torch
+import numpy as np
+import torchvision
+from transformers.image_utils import ImageInput, make_nested_list_of_images
+from torch import nn
+from torch.nn import functional as F, LayerNorm
+from torchvision.transforms.functional import InterpolationMode
+from transformers.activations import ACT2FN
+from torchvision import transforms
+from torchvision.transforms.functional import InterpolationMode
+from transformers.feature_extraction_utils import BatchFeature
+from transformers.image_utils import ImageInput
+from transformers.processing_utils import ImagesKwargs, MultiModalData, ProcessingKwargs, ProcessorMixin, Unpack
+from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
+from transformers.utils import logging
+from transformers.video_utils import VideoInput
+from transformers import BatchFeature, PretrainedConfig, TensorType
+from transformers.image_utils import make_flat_list_of_images
+from math import ceil
+from itertools import product
+from transformers import LlamaTokenizerFast
+
+
+MAX_IMAGE_SIZE: int = 3024
+
+class Step3VLImagePixelInputs(TypedDict):
+    type: Literal["pixel_values"]
+    pixel_values: torch.Tensor
+    patch_pixel_values: Optional[torch.Tensor]
+    num_patches: list[int]
+
+
+class Step3VLImageEmbeddingInputs(TypedDict):
+    type: Literal["image_embeds"]
+    image_embeds: torch.Tensor
+
+
+Step3VLImageInputs = Union[Step3VLImagePixelInputs,
+                           Step3VLImageEmbeddingInputs]
+
+ImageWithPatches = tuple[Image.Image, list[Image.Image], list[int] | None]
+
+
+class GPUToTensor(torch.nn.Module):
+
+    def forward(self, raw_image: Union[np.ndarray,
+                                       Image.Image]) -> torch.Tensor:
+        if isinstance(raw_image, Image.Image):
+            return transforms.ToTensor()(raw_image)
+        if raw_image.ndim == 2:
+            raw_image = raw_image[:, :, None].repeat(3, -1)
+        if torch.cuda.is_available():
+            device = torch.device("cuda")
+        else:
+            device = torch.device("cpu")
+        image_tensor = torch.from_numpy(raw_image).to(device)
+        image_tensor = torch.permute(image_tensor, (2, 0, 1)).contiguous()
+        if image_tensor.dtype == torch.uint8:
+            image_tensor = image_tensor.to(torch.float32).div(255)
+        return image_tensor
+
+class Step3VisionProcessor:
+
+    def __init__(self, size, interpolation_mode="bicubic", patch_size=None):
+        mean = [0.48145466, 0.4578275, 0.40821073]
+        std = [0.26862954, 0.26130258, 0.27577711]
+        patch_size = patch_size if patch_size is not None else size
+
+        self.transform = transforms.Compose([
+            GPUToTensor(),
+            transforms.Normalize(mean, std),
+            transforms.Resize(
+                (size, size),
+                interpolation=InterpolationMode.BICUBIC if interpolation_mode
+                == "bicubic" else InterpolationMode.BILINEAR,
+                antialias=True),
+        ])
+
+        self.patch_transform = transforms.Compose([
+            GPUToTensor(),
+            transforms.Normalize(mean, std),
+            transforms.Resize(
+                (patch_size, patch_size),
+                interpolation=InterpolationMode.BICUBIC if interpolation_mode
+                == "bicubic" else InterpolationMode.BILINEAR,
+                antialias=True),
+        ]) if patch_size is not None else None
+
+    def __call__(self, image, is_patch=False):
+        if is_patch:
+            return {"pixel_values": self.patch_transform(image).unsqueeze(0)}
+        else:
+            return {"pixel_values": self.transform(image).unsqueeze(0)}
+
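For reference, a minimal sketch (illustrative only, not part of the committed file) of what Step3VisionProcessor produces; the 728/504 sizes match the values hard-coded in Step3VLProcessor further down, and the input image is a made-up placeholder:

from PIL import Image

proc = Step3VisionProcessor(size=728, interpolation_mode="bilinear", patch_size=504)
img = Image.new("RGB", (1024, 768))              # hypothetical input image

full = proc(img)["pixel_values"]                 # normalized, resized global view
crop = proc(img, is_patch=True)["pixel_values"]  # patch-sized variant
print(full.shape)   # torch.Size([1, 3, 728, 728])
print(crop.shape)   # torch.Size([1, 3, 504, 504])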
+class ImagePatcher:
+    def determine_window_size(self, long: int, short: int) -> int:
+        if long <= 728:
+            return short if long / short > 1.5 else 0
+        return min(short, 504) if long / short > 4 else 504
+    def slide_window(
+        self,
+        width: int,
+        height: int,
+        sizes: list[tuple[int, int]],
+        steps: list[tuple[int, int]],
+        img_rate_thr: float = 0.6,
+    ) -> tuple[list[tuple[int, int, int, int]], tuple[int, int]]:
+        assert 1 >= img_rate_thr >= 0, "The `in_rate_thr` should lie in 0~1"
+        windows = []
+        # Sliding windows.
+        for size, step in zip(sizes, steps):
+            size_w, size_h = size
+            step_w, step_h = step
+
+            x_num = 1 if width <= size_w else ceil((width - size_w) / step_w +
+                                                   1)
+            x_start = [step_w * i for i in range(x_num)]
+            if len(x_start) > 1 and x_start[-1] + size_w > width:
+                x_start[-1] = width - size_w
+
+            y_num = 1 if height <= size_h else ceil((height - size_h) /
+                                                    step_h + 1)
+            y_start = [step_h * i for i in range(y_num)]
+            if len(y_start) > 1 and y_start[-1] + size_h > height:
+                y_start[-1] = height - size_h
+
+            start = np.array(list(product(y_start, x_start)), dtype=int)
+            start[:, [0, 1]] = start[:, [1, 0]]
+            windows.append(np.concatenate([start, start + size], axis=1))
+        windows = np.concatenate(windows, axis=0)
+
+        return [(int(box[0]), int(box[1]), int(box[2] - box[0]),
+                 int(box[3] - box[1])) for box in windows], (x_num, y_num)
+
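As a worked example (illustrative only), tiling a hypothetical 1512x1008 image with a 504-pixel window and stride yields a 3x2 grid of crops:

patcher = ImagePatcher()
windows, (x_num, y_num) = patcher.slide_window(
    1512, 1008, [(504, 504)], [(504, 504)])
# x_num = ceil((1512 - 504) / 504 + 1) = 3, y_num = ceil((1008 - 504) / 504 + 1) = 2
print(len(windows), x_num, y_num)   # 6 3 2
print(windows[0])                   # (0, 0, 504, 504), i.e. (x, y, w, h)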
+    def square_pad(self, img: Image.Image) -> Image.Image:
+        w, h = img.size
+        if w == h:
+            return img
+        size = max(w, h)
+        padded = Image.new(img.mode, (size, size), 0)
+        padded.paste(img, (0, 0))
+        return padded
+
+    def get_image_size_for_padding(self, img_width: int,
+                                   img_height: int) -> tuple[int, int]:
+        ratio = img_width / img_height
+        if min(img_height, img_width) < 32 and (ratio > 4 or ratio < 1 / 4):
+            new_size = max(img_height, img_width)
+            return new_size, new_size
+        return img_width, img_height
+
+    def get_image_size_for_preprocess(self, img_width: int,
+                                      img_height: int) -> tuple[int, int]:
+
+        if max(img_height, img_width) > MAX_IMAGE_SIZE:
+            scale_factor = MAX_IMAGE_SIZE / max(img_height, img_width)
+            img_width = int(img_width * scale_factor)
+            img_height = int(img_height * scale_factor)
+        return img_width, img_height
+
+    def get_image_size_for_crop(self, img_width: int, img_height: int,
+                                window_size: int):
+        w_ratio = img_width / window_size
+        h_ratio = img_height / window_size
+
+        if w_ratio < 1:
+            width_new = img_width
+        else:
+            decimal_w = w_ratio - img_width // window_size
+            w_ratio = int(w_ratio) + 1 if decimal_w > 0.2 else int(w_ratio)
+            width_new = window_size * w_ratio
+        if h_ratio < 1:
+            height_new = img_height
+        else:
+            decimal_h = h_ratio - img_height // window_size
+            h_ratio = int(h_ratio) + 1 if decimal_h > 0.2 else int(h_ratio)
+            height_new = window_size * h_ratio
+        return int(width_new), int(height_new)
+
+    def patch_crop(self, img: Image.Image, i: int, j: int, th: int, tw: int):
+        target = img.crop((j, i, j + tw, i + th))
+        return target
+
+    def get_num_patches(self, img_width: int,
+                        img_height: int) -> tuple[int, int]:
+        img_width, img_height = self.get_image_size_for_padding(
+            img_width, img_height)
+        img_width, img_height = self.get_image_size_for_preprocess(
+            img_width, img_height)
+        window_size = self.determine_window_size(max(img_height, img_width),
+                                                 min(img_height, img_width))
+        if window_size == 0:
+            return 0, 0
+        else:
+            img_width, img_height = self.get_image_size_for_crop(
+                img_width, img_height, window_size)
+            center_list, (x_num, y_num) = self.slide_window(
+                img_width, img_height, [(window_size, window_size)],
+                [(window_size, window_size)])
+            full_rows = (len(center_list) - 1) // x_num + 1
+            if len(center_list) > 0 and len(center_list) % x_num == 0:
+                full_rows -= 1
+            return len(center_list), full_rows
+
+    def __call__(
+        self, img: Image.Image
+    ) -> tuple[Image.Image, list[Image.Image], list[bool] | None]:
+        img_width, img_height = img.size
+        new_img_width, new_img_height = self.get_image_size_for_padding(
+            img_width, img_height)
+        if new_img_width != img_width or new_img_height != img_height:
+            img = self.square_pad(img)
+            img_width, img_height = img.size
+
+        new_img_width, new_img_height = self.get_image_size_for_preprocess(
+            img_width, img_height)
+        img = img.resize((new_img_width, new_img_height),
+                         Image.Resampling.BILINEAR)
+        window_size = self.determine_window_size(
+            max(new_img_height, new_img_width),
+            min(new_img_height, new_img_width))
+
+        if window_size == 0:
+            return img, [], None
+        else:
+            new_img_width, new_img_height = self.get_image_size_for_crop(
+                new_img_width, new_img_height, window_size)
+            if (new_img_width, new_img_height) != (img_width, img_height):
+                img_for_crop = img.resize((new_img_width, new_img_height),
+                                          Image.Resampling.BILINEAR)
+            else:
+                img_for_crop = img
+
+            patches = []
+            newlines = []
+            center_list, (x_num, y_num) = self.slide_window(
+                new_img_width, new_img_height, [(window_size, window_size)],
+                [(window_size, window_size)])
+            for patch_id, center_lf_point in enumerate(center_list):
+                x, y, patch_w, patch_h = center_lf_point
+                big_patch = self.patch_crop(img_for_crop, y, x, patch_h,
+                                            patch_w)
+                patches.append(big_patch)
+                if (patch_id + 1) % x_num == 0:
+                    newlines.append(patch_id)
+
+            if newlines and newlines[-1] == len(patches) - 1:
+                newlines.pop()
+
+            return img, patches, [i in newlines for i in range(len(patches))] if len(patches) > 0 else None
+
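The thresholds in determine_window_size decide whether an image is cropped at all; a few hypothetical sizes show the three regimes (illustrative only, not part of the committed file):

patcher = ImagePatcher()
print(patcher.determine_window_size(728, 600))    # 0   -> small, near-square: no crops
print(patcher.determine_window_size(728, 400))    # 400 -> small but elongated: window = short side
print(patcher.determine_window_size(1600, 1200))  # 504 -> large image: fixed 504 window
print(patcher.get_num_patches(1600, 1200))        # (9, 2): snapped to 1512x1512, 3x3 crops, 2 row breaks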
+class Step3VLProcessor(ProcessorMixin):
+    attributes = ["tokenizer"]
+    tokenizer_class = "AutoTokenizer"
+
+    def __init__(
+        self,
+        tokenizer,
+        chat_template=None,
+        **kwargs
+    ) -> None:
+        self.image_size = 728
+        self.patch_size = 504
+
+        self.image_preprocessor = Step3VisionProcessor(self.image_size,
+                                                       "bilinear",
+                                                       self.patch_size)
+
+        self.num_image_feature_size = 169
+        self.num_patch_feature_size = 81
+        self.image_token = "<im_patch>"
+        self.image_feature_placeholder = (self.image_token *
+                                          self.num_image_feature_size)
+        self.patch_feature_placeholder = (self.image_token *
+                                          self.num_patch_feature_size)
+
+        self.patcher = ImagePatcher()
+        super().__init__(tokenizer=tokenizer, chat_template=chat_template, **kwargs)
+
+    @property
+    def image_token_id(self) -> int:
+        return self.tokenizer.get_vocab()[self.image_token]
+
+    def get_num_image_tokens(self, img_width: int, img_height: int) -> int:
+        num_patches, num_newlines = self.patcher.get_num_patches(
+            img_width, img_height)
+
+        return num_patches * (
+            self.num_patch_feature_size +
+            2) + self.num_image_feature_size + 2 + num_newlines
+
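Worked through for a hypothetical 1600x1200 image (illustrative only): get_num_patches returns (9, 2), each crop costs 81 patch features plus <patch_start>/<patch_end>, the global view costs 169 features plus <im_start>/<im_end>, and 2 patch-newline tokens are added, so the image expands to 9 * 83 + 171 + 2 = 920 tokens.

processor = Step3VLProcessor(tokenizer)             # assumes a suitable tokenizer is already loaded
print(processor.get_num_image_tokens(1600, 1200))   # 9 * (81 + 2) + 169 + 2 + 2 = 920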
+    def _split_images(self,
+                      images: list[Image.Image]) -> list[ImageWithPatches]:
+        result = []
+        for img in images:
+            result.append(self.patcher(img))
+        return result
+
+    def _convert_images_to_pixel_values(
+        self,
+        images: list[Image.Image],
+        is_patch: bool = False,
+    ) -> list[torch.Tensor]:
+        return [
+            self.image_preprocessor(img, is_patch=is_patch)["pixel_values"]
+            for img in images
+        ]
+
+    def _get_patch_repl(
+        self,
+        num_patches: int,
+        patch_newline_mask: list[bool] | None,
+    ) -> tuple[str, list[int]]:
+        text = ""
+        token_ids = []
+        for i in range(num_patches):
+            assert len(patch_newline_mask) == num_patches
+            text += f"<patch_start>{self.patch_feature_placeholder}<patch_end>"
+            token_ids.extend(
+                [self.tokenizer.convert_tokens_to_ids("<patch_start>")] +
+                [self.image_token_id] * self.num_patch_feature_size +
+                [self.tokenizer.convert_tokens_to_ids("<patch_end>")])
+            if patch_newline_mask and patch_newline_mask[i]:
+                text += "<patch_newline>"
+                token_ids.append(
+                    self.tokenizer.convert_tokens_to_ids("<patch_newline>"))
+        return text, token_ids
+
+    def _get_image_repl(
+        self,
+        num_images: int,
+    ) -> tuple[str, list[int]]:
+        text = f"<im_start>{self.image_feature_placeholder}<im_end>"
+        token_ids = [
+            self.tokenizer.convert_tokens_to_ids("<im_start>")
+        ] + [self.image_token_id] * self.num_image_feature_size + [
+            self.tokenizer.convert_tokens_to_ids("<im_end>")
+        ]
+        return text * num_images, token_ids * num_images
+
+    def _get_image_repl_features(
+        self,
+        num_images: int,
+        num_patches: int,
+        patch_new_line_idx: Optional[list[bool]],
+    ) -> tuple[str, list[int]]:
+        if num_patches > 0:
+            patch_repl, patch_repl_ids = self._get_patch_repl(
+                num_patches, patch_new_line_idx)
+        else:
+            patch_repl = ""
+            patch_repl_ids = []
+        image_repl, image_repl_ids = self._get_image_repl(num_images)
+        return patch_repl + image_repl, patch_repl_ids + image_repl_ids
+
+    def replace_placeholder(self, text: str, placeholder: str,
+                            repls: list[str]) -> str:
+        parts = text.split(placeholder)
+
+        if len(parts) - 1 != len(repls):
+            raise ValueError(
+                "The number of placeholders does not match the number of replacements."  # noqa: E501
+            )
+
+        result = [parts[0]]
+        for i, repl in enumerate(repls):
+            result.append(repl)
+            result.append(parts[i + 1])
+
+        return "".join(result)
+
+    def __call__(
+        self,
+        text: Optional[Union[str, list[str]]] = None,
+        images: Optional[Union[Image.Image, list[Image.Image]]] = None,
+        return_tensors: Optional[Union[str, TensorType]] = None,
+        **kwargs,
+    ) -> BatchFeature:
+        if text is None:
+            text = []
+        if not isinstance(text, list):
+            text = [text]
+        if images is None:
+            images = []
+        elif not isinstance(images, list):
+            images = [images]
+        elif isinstance(images[0], list):
+            images = images[0]
+
+        if len(images) == 0:
+            image_inputs = {}
+            text_inputs = self.tokenizer(text)
+        else:
+            splitted_images_data = self._split_images(images)
+            pixel_values_lst = []
+            patch_pixel_values_lst = []
+            patch_newline_mask_lst = []
+            image_repl_str_lst = []
+            image_repl_ids_lst = []
+            num_patches = []
+            for raw_img, img_patches, patch_newline_mask in splitted_images_data:  # noqa: E501
+                pixel_values_lst.extend(
+                    self._convert_images_to_pixel_values([raw_img]))
+
+                if len(img_patches) > 0:
+                    patch_pixel_values_lst.extend(
+                        self._convert_images_to_pixel_values(img_patches,
+                                                             is_patch=True))
+                num_patches.append(len(img_patches))
+
+                image_repl_str, image_repl_ids = self._get_image_repl_features(
+                    1, len(img_patches), patch_newline_mask)
+                image_repl_str_lst.append(image_repl_str)
+                image_repl_ids_lst.extend(image_repl_ids)
+
+                if patch_newline_mask is not None:
+                    patch_newline_mask_lst.extend(patch_newline_mask)
+
+            image_inputs = {
+                "pixel_values": torch.cat(pixel_values_lst),
+                "num_patches": num_patches,
+            }
+            if patch_pixel_values_lst:
+                image_inputs["patch_pixel_values"] = torch.cat(
+                    patch_pixel_values_lst)
+            if patch_newline_mask_lst:
+                image_inputs["patch_newline_mask"] = torch.tensor(
+                    patch_newline_mask_lst, dtype=torch.bool)
+
+            text = [
+                self.replace_placeholder(t, self.image_token,
+                                         image_repl_str_lst) for t in text
+            ]
+            text_inputs = self.tokenizer(text)
+
+        return BatchFeature(
+            {
+                **text_inputs,
+                **image_inputs,
+            },
+            tensor_type=return_tensors,
+        )
+
+    # Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Gemma
+    def batch_decode(self, *args, **kwargs):
+        """
+        This method forwards all its arguments to GemmaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
+        refer to the docstring of this method for more information.
+        """
+        return self.tokenizer.batch_decode(*args, **kwargs)
+
+    # Copied from transformers.models.clip.processing_clip.CLIPProcessor.decode with CLIP->Gemma
+    def decode(self, *args, **kwargs):
+        """
+        This method forwards all its arguments to GemmaTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
+        the docstring of this method for more information.
+        """
+        return self.tokenizer.decode(*args, **kwargs)
+
+__all__ = ["Step3VLProcessor"]
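A minimal end-to-end sketch of how the new processor might be exercised (illustrative only; it assumes the repository's tokenizer already defines the special tokens <im_patch>, <im_start>, <im_end>, <patch_start>, <patch_end> and <patch_newline>, and the local path is a placeholder):

from PIL import Image
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("./", trust_remote_code=True)  # placeholder path
processor = Step3VLProcessor(tokenizer)

image = Image.new("RGB", (1600, 1200))  # stand-in for a real image
prompt = f"Describe this image: {processor.image_token}"

batch = processor(text=prompt, images=image, return_tensors="pt")
print(batch["pixel_values"].shape)        # (1, 3, 728, 728) global view
print(batch["patch_pixel_values"].shape)  # (9, 3, 504, 504) crops for this size
print(batch["input_ids"].shape)           # prompt tokens with 920 image tokens expanded in place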
processor_config.json CHANGED
@@ -1,5 +1,5 @@
 {
   "auto_map": {
-    "AutoProcessor": "processing_step3v.Step3VLProcessor"
+    "AutoProcessor": "processing_step3.Step3VLProcessor"
   }
-}
+}
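With the auto_map now pointing at the freshly uploaded processing_step3.py module, the processor resolves through the standard remote-code path; a minimal sketch (the repo id is a placeholder):

from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("<repo-id>", trust_remote_code=True)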