"""
Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
SPDX-License-Identifier: MIT
"""

import numpy as np
import torch
from PIL import ImageOps
from torchvision import transforms
from torchvision.transforms.functional import resize

IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)


class DolphinProcessor:
    """Prepares text prompts and page images for Dolphin model inference."""

    def __init__(
        self,
        dp_config,
        tokenizer,
        **kwargs,
    ) -> None:
        self.tokenizer = tokenizer
        transform_args = kwargs.get("transform_args", {})
        self.max_length = transform_args.get("max_length", 2048)
        self.input_size = transform_args.get("input_size", [896, 896])  # height, width
        if isinstance(self.input_size, int):
            self.input_size = [self.input_size, self.input_size]

        # The tokenizer's prompt-end token marks where the answer begins;
        # fall back to "" for tokenizers that do not define it.
        try:
            self.answer_start_token = self.tokenizer._prompt_end_token
        except AttributeError:
            print('No _prompt_end_token found on tokenizer, using "" as answer_start_token')
            self.answer_start_token = ""

        self.prefix_answer_space_flag = dp_config.get("prefix_answer_space_flag", True)
        self.suffix_prompt_space_flag = dp_config.get("suffix_prompt_space_flag", True)

        # Convert PIL images to float tensors and normalize with ImageNet statistics.
        self.transform = transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD)]
        )

    def process_prompt_for_inference(self, prompt):
        """Strip the image placeholder, ensure the "<s>" prefix, and tokenize
        the prompt into a (1, seq_len) tensor of token ids."""
        prompt = prompt.replace("<image>\n", "")
        if not prompt.startswith("<s>"):
            prompt = "<s>" + prompt
        message_ids = [self.tokenizer.encode(prompt, add_special_tokens=False)]
        # Note: np.hstack's dtype keyword requires NumPy >= 1.24.
        ids = torch.from_numpy(np.hstack(message_ids, dtype=np.int32))
        return ids.unsqueeze(0)

    def process_image_for_inference(self, image, return_img_size=False):
        """Resize, letterbox-pad, and normalize a PIL image to `input_size`.

        Returns a (1, 3, H, W) tensor, plus the pre-padding (width, height)
        when `return_img_size` is True.
        """
        # Scale the shorter edge to the target, then shrink in place so the
        # whole image fits within (width, height) = (input_size[1], input_size[0]).
        image = resize(image, min(self.input_size))
        image.thumbnail((self.input_size[1], self.input_size[0]))
        origin_w, origin_h = image.size

        # Pad symmetrically on each side to reach the exact target size.
        delta_width = self.input_size[1] - image.width
        delta_height = self.input_size[0] - image.height
        pad_width = delta_width // 2
        pad_height = delta_height // 2
        padding = (
            pad_width,
            pad_height,
            delta_width - pad_width,
            delta_height - pad_height,
        )
        image = ImageOps.expand(image, padding)
        if return_img_size:
            return self.transform(image).unsqueeze(0), (origin_w, origin_h)
        return self.transform(image).unsqueeze(0)
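

if __name__ == "__main__":
    # Usage sketch, not part of the original module: the checkpoint path,
    # prompt text, and image file below are hypothetical placeholders, and the
    # tokenizer is assumed to be a HuggingFace tokenizer exposing `encode`.
    from PIL import Image
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("./dolphin_checkpoint")  # hypothetical path
    processor = DolphinProcessor({"prefix_answer_space_flag": True}, tokenizer)

    prompt_ids = processor.process_prompt_for_inference("Read text in the image.")
    pixel_values, (w, h) = processor.process_image_for_inference(
        Image.open("page.png").convert("RGB"), return_img_size=True
    )
    print(prompt_ids.shape, pixel_values.shape, (w, h))  # e.g. (1, N), (1, 3, 896, 896)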