""" | |
Copyright (c) 2025 Bytedance Ltd. and/or its affiliates | |
SPDX-License-Identifier: MIT | |
""" | |
import os | |
import warnings | |
from collections import OrderedDict | |
from omegaconf import ListConfig | |
warnings.filterwarnings("ignore", category=UserWarning) | |
warnings.filterwarnings("ignore", category=FutureWarning) | |
os.environ.setdefault("PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION", "python") | |
import torch | |
from PIL import Image | |
from transformers import PreTrainedTokenizerFast | |
from utils.model import DonutConfig, DonutModel, SwinEncoder | |
from utils.processor import DolphinProcessor | |


def try_rename_legacy_weights(ckpt, output_path=""):
    """Rename legacy checkpoint keys to the current naming scheme.

    Strips an optional "model." prefix, then maps "encoder*" keys to "vpm*"
    (vision tower) and "decoder*" keys to "llm*" (language model), e.g.
    "model.encoder.patch_embed.proj.weight" -> "vpm.patch_embed.proj.weight".
    """
    if "state_dict" in ckpt.keys():
        ckpt = ckpt["state_dict"]
    if "module" in ckpt.keys():
        ckpt = ckpt["module"]
    new_ckpt = OrderedDict()
    for k, v in ckpt.items():
        if k.startswith("model."):
            k = k[len("model.") :]
        if k.startswith("encoder"):
            new_ckpt["vpm" + k[len("encoder") :]] = v
        elif k.startswith("decoder"):
            new_ckpt["llm" + k[len("decoder") :]] = v
        else:
            new_ckpt[k] = v
    if output_path:
        torch.save(new_ckpt, output_path)
    return new_ckpt


def convert_listconfig_to_list(config):
    """Convert OmegaConf ListConfig values into plain Python lists."""
    new_config = {}
    for k, v in config.items():
        if isinstance(v, ListConfig):
            new_config[k] = list(v)
        else:
            new_config[k] = v
    return new_config
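
# For example, a YAML field like `img_size: [896, 896]` is loaded by OmegaConf as a
# ListConfig; after conversion it becomes the plain list [896, 896] that SwinEncoder
# expects below. (Illustrative values; the actual config ships with the repo.)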


class DOLPHIN:
    def __init__(self, config, ckpt_path="") -> None:
        self.model_args = config.model
        # Swin encoder settings live under model.swin_args; pop them so they are
        # not passed to the decoder config below.
        self.swin_args = config.model.pop("swin_args")
        self.swin_args = convert_listconfig_to_list(self.swin_args)

        vision_tower = SwinEncoder(
            input_size=self.swin_args["img_size"],
            patch_size=self.swin_args["patch_size"],
            embed_dim=self.swin_args["embed_dim"],
            window_size=self.swin_args["window_size"],
            encoder_layer=self.swin_args["encoder_layer"],
            num_heads=self.swin_args["num_heads"],
            align_long_axis=self.swin_args["align_long_axis"],
        )

        self.tokenizer = PreTrainedTokenizerFast(tokenizer_file=self.model_args.tokenizer_path)
        self.tokenizer.pad_token = "<pad>"
        self.tokenizer.bos_token = "<s>"
        self.tokenizer.eos_token = "</s>"
        self.tokenizer.unk_token = "<unk>"

        if self.model_args.get("extra_answer_tokens", False):
            # Multitask training: add the <Answer/> prompt terminator to the tokenizer.
            prompt_end_token = " <Answer/>"
            self.tokenizer.add_special_tokens({"additional_special_tokens": sorted(set([prompt_end_token]))})
            self.tokenizer._prompt_end_token = prompt_end_token
            self.tokenizer._prompt_end_token_id = self.tokenizer.convert_tokens_to_ids(prompt_end_token)

        donut_config = DonutConfig(
            decoder_layer=self.model_args.decoder_layer,
            max_length=self.model_args.max_length,
            max_position_embeddings=self.model_args.max_position_embeddings,
            hidden_dimension=self.model_args.hidden_dimension,
        )
        self.model = DonutModel(config=donut_config, vision_tower=vision_tower, tokenizer=self.tokenizer)

        if self.model_args.model_name_or_path:
            # map_location="cpu" keeps loading robust on CPU-only hosts; the weights
            # are moved to the target device below.
            ckpt = torch.load(self.model_args.model_name_or_path, map_location="cpu")
            ckpt = try_rename_legacy_weights(ckpt)
            self.model.load_state_dict(ckpt, strict=True)

        device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model.to(device)
        self.model.eval()

        transform_args = {
            "input_size": self.swin_args["img_size"],
            "max_length": self.model_args.max_length,
        }
        self.processor = DolphinProcessor({}, self.tokenizer, transform_args=transform_args)

    def chat(
        self,
        question,
        image,
        return_raw=False,
        return_score=False,
        return_img_size=False,
        only_return_img_size=False,
        max_batch_size=16,
    ):
        """Answer a prompt about an image; accepts single items or parallel lists for batching."""

        def _preprocess_image(image):
            if isinstance(image, str):
                image = Image.open(image).convert("RGB")
            if return_img_size or only_return_img_size:
                image_tensor, ori_size = self.processor.process_image_for_inference(image, return_img_size=True)
            else:
                image_tensor = self.processor.process_image_for_inference(image, return_img_size=False)
                ori_size = None
            return image_tensor, ori_size

        def _preprocess_prompt(question):
            if self.model_args.get("extra_answer_tokens", False):
                if self.tokenizer._prompt_end_token not in question:
                    question = question + self.tokenizer._prompt_end_token
            prompt_ids = self.processor.process_prompt_for_inference(question)
            return prompt_ids

        def _preprocess_prompt_batch(question):
            if self.model_args.get("extra_answer_tokens", False):
                for i in range(len(question)):
                    if self.tokenizer._prompt_end_token not in question[i]:
                        question[i] = question[i] + self.tokenizer._prompt_end_token
                    if not question[i].startswith("<s>"):
                        question[i] = "<s>" + question[i]
            return question

        def _postprocess(output, question):
            # Drop special tokens and the echoed prompt from the decoded sequence.
            output = output.replace("<s>", "").replace(question, "").replace("</s>", "").replace("<pad>", "")
            if self.model_args.get("extra_answer_tokens", False):
                output = output.split(self.tokenizer._prompt_end_token)[-1]
            return output

        if isinstance(question, list):
            # Batched mode: stack the image tensors and left-pad the tokenized prompts.
            image_tensor_list = []
            for i in image:
                image_tensor, ori_size = _preprocess_image(i)
                image_tensor_list.append(image_tensor)
            image_tensor = torch.cat(image_tensor_list, dim=0)

            question = _preprocess_prompt_batch(question)
            self.processor.tokenizer.padding_side = "left"
            prompt_ids = self.processor.tokenizer(
                question, add_special_tokens=False, return_tensors="pt", padding=True
            ).input_ids
        else:
            image_tensor, ori_size = _preprocess_image(image)
            prompt_ids = _preprocess_prompt(question)

        if only_return_img_size:
            return ori_size

        # Run inference in chunks of max_batch_size, then merge the per-chunk
        # outputs key by key into flat Python lists.
        model_output_batch = []
        for i in range(0, image_tensor.shape[0], max_batch_size):
            image_tensor_batch = image_tensor[i : i + max_batch_size]
            prompt_ids_batch = prompt_ids[i : i + max_batch_size]
            model_output = self.model.inference(image_tensors=image_tensor_batch, prompt_ids=prompt_ids_batch)
            model_output_batch.append(model_output)

        model_output = {}
        for k, v in model_output_batch[0].items():
            if isinstance(v, torch.Tensor):
                model_output[k] = sum(
                    [v_batch[k].cpu().numpy().tolist() for v_batch in model_output_batch],
                    [],
                )
            else:
                model_output[k] = sum([v_batch[k] for v_batch in model_output_batch], [])

        if return_raw:
            if return_img_size:
                return model_output, ori_size
            return model_output
        else:
            if isinstance(question, list):
                output = [_postprocess(model_output["repetitions"][i], question[i]) for i in range(len(question))]
                score = model_output["scores"]
            else:
                output = _postprocess(model_output["repetitions"][0], question)
                score = model_output["scores"][0]
            if return_score:
                return output, score
            if return_img_size:
                return output, ori_size
            return output
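

# A minimal usage sketch, not part of the original module: it assumes an OmegaConf
# YAML config carrying the model/swin_args fields consumed above, and the file
# paths below are hypothetical placeholders.
if __name__ == "__main__":
    from omegaconf import OmegaConf

    config = OmegaConf.load("./config/Dolphin.yaml")  # hypothetical config path
    model = DOLPHIN(config)

    # Single prompt/image pair: returns the post-processed string.
    answer = model.chat("Parse the reading order of this document.", "./demo/page.jpeg")
    print(answer)

    # Batched mode: parallel lists of prompts and images, decoded in chunks
    # of max_batch_size.
    answers = model.chat(
        ["Read text in the image.", "Parse the table in the image."],
        ["./demo/block_text.jpeg", "./demo/block_table.jpeg"],
    )
    print(answers)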