import torch
import torch.nn as nn
from transformers import PreTrainedModel, PretrainedConfig
from typing import Union
from .config import MoondreamConfig
from .moondream import MoondreamModel
# Files sometimes don't get loaded without these...
from .image_crops import *
from .vision import *
from .text import *
from .region import *
from .utils import *


def extract_question(text):
    prefix = "<image>\n\nQuestion: "
    suffix = "\n\nAnswer:"

    if text.startswith(prefix) and text.endswith(suffix):
        return text[len(prefix) : -len(suffix)]
    else:
        return None
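

# For reference: extract_question("<image>\n\nQuestion: What is this?\n\nAnswer:")
# returns "What is this?", while any prompt that does not match the legacy
# template yields None.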


class HfConfig(PretrainedConfig):
    _auto_class = "AutoConfig"
    model_type = "moondream1"

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.config = {}


class HfMoondream(PreTrainedModel):
    _auto_class = "AutoModelForCausalLM"
    config_class = HfConfig

    def __init__(self, config):
        super().__init__(config)
        self.model = MoondreamModel(
            MoondreamConfig.from_dict(config.config), setup_caches=False
        )
        self._is_kv_cache_setup = False

    def _setup_caches(self):
        if not self._is_kv_cache_setup:
            self.model._setup_caches()
            self._is_kv_cache_setup = True
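
    # The public endpoints below lazily build the KV caches on first use and
    # then delegate directly to the corresponding MoondreamModel methods.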

    @property
    def encode_image(self):
        self._setup_caches()
        return self.model.encode_image

    @property
    def query(self):
        self._setup_caches()
        return self.model.query

    @property
    def caption(self):
        self._setup_caches()
        return self.model.caption

    @property
    def detect(self):
        self._setup_caches()
        return self.model.detect

    @property
    def point(self):
        self._setup_caches()
        return self.model.point

    @property
    def detect_gaze(self):
        self._setup_caches()
        return self.model.detect_gaze

    def answer_question(
        self,
        image_embeds,
        question,
        tokenizer=None,
        chat_history="",
        result_queue=None,
        max_new_tokens=256,
        **kwargs
    ):
        """
        Legacy entry point: answer a single question about an image.
        `tokenizer`, `chat_history`, `max_new_tokens`, and extra kwargs are
        accepted for backwards compatibility but ignored.
        """
        answer = self.query(image_embeds, question)["answer"].strip()

        if result_queue is not None:
            result_queue.put(answer)

        return answer

    def batch_answer(self, images, prompts, tokenizer=None, **kwargs):
        """Legacy entry point: answer one prompt per image, sequentially."""
        answers = []
        for image, prompt in zip(images, prompts):
            answers.append(self.query(image, prompt)["answer"].strip())
        return answers
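
    # Usage sketch (img_a and img_b stand in for PIL images supplied by the caller):
    #   answers = model.batch_answer([img_a, img_b], ["Describe the scene.", "What color is the car?"])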

    def _unsupported_exception(self):
        raise NotImplementedError(
            "This method is not supported in the latest version of moondream. "
            "Consider upgrading to the updated API spec, or alternately pin "
            "to 'revision=2024-08-26'."
        )

    def generate(self, image_embeds, prompt, tokenizer, max_new_tokens=128, **kwargs):
        """
        Kept unchanged for backwards compatibility. `tokenizer` and extra
        kwargs are ignored; `max_new_tokens` only takes effect when the prompt
        does not match the legacy question template parsed by `extract_question`.
        """
        prompt_extracted = extract_question(prompt)
        if prompt_extracted is not None:
            answer = self.model.query(
                image=image_embeds, question=prompt_extracted, stream=False
            )["answer"]
        else:
            image_embeds = self.encode_image(image_embeds)
            prompt_tokens = torch.tensor(
                [self.model.tokenizer.encode(prompt).ids],
                device=self.device,
            )

            def generator():
                for token in self.model._generate_answer(
                    prompt_tokens,
                    image_embeds.kv_cache,
                    image_embeds.pos,
                    max_new_tokens,
                ):
                    yield token

            answer = "".join(list(generator()))

        return [answer]
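
    # Legacy usage sketch (pil_image is a hypothetical PIL.Image supplied by the caller):
    #   out = model.generate(pil_image, "<image>\n\nQuestion: What is shown?\n\nAnswer:", tokenizer=None)
    #   print(out[0])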

    def get_input_embeddings(self) -> nn.Embedding:
        """
        Lazily wrap the raw parameter `self.model.text.wte` in a real
        `nn.Embedding` layer so that HF mix-ins recognise it. The wrapper
        **shares** the weight tensor; no copy is made.
        """
        if not hasattr(self, "_input_embeddings"):
            self._input_embeddings = nn.Embedding.from_pretrained(
                self.model.text.wte,  # tensor created in text.py
                freeze=True,  # set to False if you need it trainable
            )
        return self._input_embeddings
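
    # Sketch of the sharing guarantee (assumes `model` is an HfMoondream instance):
    #   emb = model.get_input_embeddings()
    #   assert emb.weight.data_ptr() == model.model.text.wte.data_ptr()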

    def set_input_embeddings(self, value: Union[nn.Embedding, nn.Module]) -> None:
        """
        Lets HF functions (e.g. `resize_token_embeddings`) replace or resize the
        embeddings and keeps everything tied to `self.model.text.wte`.
        """
        # 1. Point the low-level parameter to the new weight matrix.
        self.model.text.wte = value.weight
        # 2. Keep a reference for get_input_embeddings().
        self._input_embeddings = value

    def input_embeds(
        self,
        input_ids: Union[torch.LongTensor, list, tuple],
        *,
        device: torch.device | None = None,
    ) -> torch.FloatTensor:
        """
        Back-compat wrapper that turns token IDs into embeddings.

        Example:
            ids = torch.tensor([[1, 2, 3]])
            embeds = model.input_embeds(ids)  # (1, 3, hidden_dim)
        """
        if not torch.is_tensor(input_ids):
            input_ids = torch.as_tensor(input_ids)
        if device is not None:
            input_ids = input_ids.to(device)

        return self.get_input_embeddings()(input_ids)