import torch
import torch.nn as nn

from transformers import PreTrainedModel, PretrainedConfig
from typing import Union

from .config import MoondreamConfig
from .moondream import MoondreamModel

# Star imports keep these sibling modules referenced so they are downloaded
# alongside this file when loading with trust_remote_code.
from .image_crops import *
from .vision import *
from .text import *
from .region import *
from .utils import *


def extract_question(text):
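    """
    Extract the question from the legacy prompt template
    "<image>\n\nQuestion: {question}\n\nAnswer:"; return None if the
    prompt does not match that template.
    """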
    prefix = "<image>\n\nQuestion: "
    suffix = "\n\nAnswer:"

    if text.startswith(prefix) and text.endswith(suffix):
        return text[len(prefix) : -len(suffix)]
    else:
        return None


class HfConfig(PretrainedConfig):
    _auto_class = "AutoConfig"
    model_type = "moondream1"

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Preserve a nested `config` dict if one was supplied (e.g. when
        # loading from config.json); fall back to an empty dict.
        self.config = kwargs.get("config", {})


class HfMoondream(PreTrainedModel):
    _auto_class = "AutoModelForCausalLM"
    config_class = HfConfig

    def __init__(self, config):
        super().__init__(config)
        self.model = MoondreamModel(
            MoondreamConfig.from_dict(config.config), setup_caches=False
        )
        self._is_kv_cache_setup = False

    def _setup_caches(self):
        if not self._is_kv_cache_setup:
            self.model._setup_caches()
            self._is_kv_cache_setup = True

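    # Each property below lazily initializes the KV caches on first use and
    # then delegates the call to the wrapped MoondreamModel.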
    @property
    def encode_image(self):
        self._setup_caches()
        return self.model.encode_image

    @property
    def query(self):
        self._setup_caches()
        return self.model.query

    @property
    def caption(self):
        self._setup_caches()
        return self.model.caption

    @property
    def detect(self):
        self._setup_caches()
        return self.model.detect

    @property
    def point(self):
        self._setup_caches()
        return self.model.point

    @property
    def detect_gaze(self):
        self._setup_caches()
        return self.model.detect_gaze

    def answer_question(
        self,
        image_embeds,
        question,
        tokenizer=None,
        chat_history="",
        result_queue=None,
        max_new_tokens=256,
        **kwargs
    ):
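        # Legacy signature: tokenizer, chat_history, max_new_tokens, and any
        # extra kwargs are accepted for backwards compatibility but ignored.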
        answer = self.query(image_embeds, question)["answer"].strip()

        if result_queue is not None:
            result_queue.put(answer)
        return answer

    def batch_answer(self, images, prompts, tokenizer=None, **kwargs):
        answers = []
        for image, prompt in zip(images, prompts):
            answers.append(self.query(image, prompt)["answer"].strip())
        return answers

    def _unsupported_exception(self):
        raise NotImplementedError(
            "This method is not supported in the latest version of moondream. "
            "Consider upgrading to the updated API spec, or alternately pin "
            "to 'revision=2024-08-26'."
        )

    def generate(self, image_embeds, prompt, tokenizer, max_new_tokens=128, **kwargs):
        """
        Signature kept unchanged for backwards compatibility. Note that
        `tokenizer` and `kwargs` are ignored, and `max_new_tokens` only
        takes effect when `prompt` is not in the legacy question format.
        """
        prompt_extracted = extract_question(prompt)
        if prompt_extracted is not None:
            answer = self.model.query(
                image=image_embeds, question=prompt_extracted, stream=False
            )["answer"]
        else:
            image_embeds = self.encode_image(image_embeds)
            prompt_tokens = torch.tensor(
                [self.model.tokenizer.encode(prompt).ids],
                device=self.device,
            )

            # _generate_answer yields decoded tokens incrementally; join them
            # into the final answer string.
            answer = "".join(
                self.model._generate_answer(
                    prompt_tokens,
                    image_embeds.kv_cache,
                    image_embeds.pos,
                    max_new_tokens,
                )
            )

        return [answer]

    def get_input_embeddings(self) -> nn.Embedding:
        """
        Lazily wrap the raw parameter `self.model.text.wte` in a real
        `nn.Embedding` layer so that HF mix-ins recognise it.  The wrapper
        **shares** the weight tensor—no copy is made.
        """
        if not hasattr(self, "_input_embeddings"):
            self._input_embeddings = nn.Embedding.from_pretrained(
                self.model.text.wte,  # tensor created in text.py
                freeze=True,  # set to False if you need it trainable
            )
        return self._input_embeddings

    def set_input_embeddings(self, value: Union[nn.Embedding, nn.Module]) -> None:
        """
        Lets HF functions (e.g. `resize_token_embeddings`) replace or resize the
        embeddings and keeps everything tied to `self.model.text.wte`.
        """
        # 1. point the low-level parameter to the new weight matrix
        self.model.text.wte = value.weight
        # 2. keep a reference for get_input_embeddings()
        self._input_embeddings = value
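
    # Example (token count is illustrative): `resize_token_embeddings` from
    # the HF mix-in routes through the two methods above, so a resized matrix
    # stays tied to `self.model.text.wte`:
    #
    #     model.resize_token_embeddings(new_num_tokens=51200)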

    def input_embeds(
        self,
        input_ids: Union[torch.LongTensor, list, tuple],
        *,
        device: Union[torch.device, None] = None
    ) -> torch.FloatTensor:
        """
        Back-compat wrapper that turns token IDs into embeddings.

        Example:
            ids = torch.tensor([[1, 2, 3]])
            embeds = model.input_embeds(ids)      # (1, 3, hidden_dim)
        """
        if not torch.is_tensor(input_ids):
            input_ids = torch.as_tensor(input_ids)
        if device is not None:
            input_ids = input_ids.to(device)

        return self.get_input_embeddings()(input_ids)
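

# Minimal usage sketch (model id and image path are illustrative; assumes a
# checkpoint that ships this class as remote code, e.g. vikhyatk/moondream2):
#
#     from PIL import Image
#     from transformers import AutoModelForCausalLM
#
#     model = AutoModelForCausalLM.from_pretrained(
#         "vikhyatk/moondream2", trust_remote_code=True
#     )
#     image = Image.open("photo.jpg")
#     print(model.query(image, "What is in this image?")["answer"])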