from threading import Thread

import numpy as np
import torch
import torch.nn as nn
from PIL import Image
from transformers import TextIteratorStreamer

from .model import get_model_name_from_path, load_pretrained_model
from .model.constants import DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX
from .model.conversation import SeparatorStyle, conv_templates
from .model.mm_utils import KeywordsStoppingCriteria, process_image, tokenizer_image_token


class DescribeAnythingModel(nn.Module):
    """Vision-language model wrapper that generates a text description for the image region specified by a binary mask."""

    def __init__(self, model_path, conv_mode, prompt_mode, temperature, top_p, num_beams, max_new_tokens, **kwargs):
        super().__init__()

        self.model_path = model_path
        self.conv_mode = conv_mode
        self.prompt_mode = prompt_mode
        self.temperature = temperature
        self.top_p = top_p
        self.num_beams = num_beams
        self.max_new_tokens = max_new_tokens

        tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, None, None, **kwargs)
        model.config.image_processor = image_processor

        self.tokenizer = tokenizer
        self.model = model
        self.context_len = context_len

        self.model_name = get_model_name_from_path(model_path)

    def get_prompt(self, qs):
        """Wrap the query in the selected conversation template and return (prompt, conv)."""
        if DEFAULT_IMAGE_TOKEN not in qs:
            raise ValueError(f"No {DEFAULT_IMAGE_TOKEN} token found in the input query.")

        conv = conv_templates[self.conv_mode].copy()
        conv.append_message(conv.roles[0], qs)
        conv.append_message(conv.roles[1], None)
        prompt = conv.get_prompt()

        return prompt, conv

    @staticmethod
    def mask_to_box(mask_np):
        """Return the tight bounding box (x0, y0, w, h) around the nonzero pixels of a binary mask."""
        mask_coords = np.argwhere(mask_np)
        y0, x0 = mask_coords.min(axis=0)
        y1, x1 = mask_coords.max(axis=0) + 1

        h = y1 - y0
        w = x1 - x0

        return x0, y0, w, h

    @classmethod
    def crop_image(cls, pil_img, mask_np, crop_mode, min_box_w=48, min_box_h=48):
        if crop_mode == "full":
            # Use the full image together with the full-resolution mask.
            info = dict(mask_np=mask_np)
            return pil_img, info

        if crop_mode == "crop":
            # Tight crop around the mask's bounding box.
            x0, y0, w, h = cls.mask_to_box(mask_np)
            img_np = np.asarray(pil_img)
            assert img_np.shape[:2] == mask_np.shape, f"image shape does not match mask shape: {img_np.shape}, {mask_np.shape}"
            cropped_mask_np = mask_np[y0:y0+h, x0:x0+w]
            cropped_img_np = img_np[y0:y0+h, x0:x0+w]
            cropped_pil_img = Image.fromarray(cropped_img_np)
        elif crop_mode == "context_crop":
            # Crop the bounding box expanded by one box width/height on each side for extra context.
            x0, y0, w, h = cls.mask_to_box(mask_np)
            img_np = np.asarray(pil_img)
            assert img_np.shape[:2] == mask_np.shape, f"image shape does not match mask shape: {img_np.shape}, {mask_np.shape}"
            img_h, img_w = img_np.shape[:2]
            cropped_mask_np = mask_np[max(y0-h, 0):min(y0+2*h, img_h), max(x0-w, 0):min(x0+2*w, img_w)]
            cropped_img_np = img_np[max(y0-h, 0):min(y0+2*h, img_h), max(x0-w, 0):min(x0+2*w, img_w)]
            cropped_pil_img = Image.fromarray(cropped_img_np)
        elif crop_mode == "focal_crop":
            # Like context_crop, but first grow the box to at least (min_box_w, min_box_h) around its center.
            x0, y0, w, h = cls.mask_to_box(mask_np)
            img_np = np.asarray(pil_img)
            assert img_np.shape[:2] == mask_np.shape, f"image shape does not match mask shape: {img_np.shape}, {mask_np.shape}"
            img_h, img_w = img_np.shape[:2]

            xc, yc = x0 + w/2, y0 + h/2
            w, h = max(w, min_box_w), max(h, min_box_h)
            x0, y0 = int(xc - w / 2), int(yc - h / 2)

            cropped_mask_np = mask_np[max(y0-h, 0):min(y0+2*h, img_h), max(x0-w, 0):min(x0+2*w, img_w)]
            cropped_img_np = img_np[max(y0-h, 0):min(y0+2*h, img_h), max(x0-w, 0):min(x0+2*w, img_w)]
            cropped_pil_img = Image.fromarray(cropped_img_np)
        elif crop_mode == "crop_mask":
            # Tight crop around the bounding box with pixels outside the mask zeroed out.
            x0, y0, w, h = cls.mask_to_box(mask_np)
            img_np = np.asarray(pil_img)
            assert img_np.shape[:2] == mask_np.shape, f"image shape does not match mask shape: {img_np.shape}, {mask_np.shape}"
            cropped_mask_np = mask_np[y0:y0+h, x0:x0+w]
            cropped_img_np = img_np[y0:y0+h, x0:x0+w]

            cropped_img_np = cropped_img_np * cropped_mask_np[..., None]
            cropped_pil_img = Image.fromarray(cropped_img_np)
        else:
            raise ValueError(f"Unsupported crop_mode: {crop_mode}")

        info = dict(mask_np=cropped_mask_np)
        return cropped_pil_img, info

    def get_description(self, image_pil, mask_pil, query, streaming=False):
        prompt, conv = self.get_prompt(query)
        if not isinstance(image_pil, (list, tuple)):
            assert not isinstance(mask_pil, (list, tuple)), "image_pil and mask_pil must either both be lists/tuples or both be single items."
            image_pils = [image_pil]
            mask_pils = [mask_pil]
        else:
            image_pils = image_pil
            mask_pils = mask_pil
        description = self.get_description_from_prompt(image_pils, mask_pils, prompt, conv, streaming=streaming)

        return description

    def get_image_tensor(self, image_pil, mask_pil, crop_mode, crop_mode2):
        # Binarize the mask and preprocess the image according to the first crop mode.
        mask_np = (np.asarray(mask_pil) > 0).astype(np.uint8)
        images_tensor, image_info = process_image(image_pil, self.model.config, None, pil_preprocess_fn=lambda pil_img: self.crop_image(image_pil, mask_np=mask_np, crop_mode=crop_mode))
        images_tensor = images_tensor[None].to(self.model.device, dtype=torch.float16)

        mask_np = image_info["mask_np"]
        mask_pil = Image.fromarray(mask_np * 255)

        masks_tensor = process_image(mask_pil, self.model.config, None)
        masks_tensor = masks_tensor[None].to(self.model.device, dtype=torch.float16)

        # Append a single mask channel to the image tensor.
        images_tensor = torch.cat((images_tensor, masks_tensor[:, :1, ...]), dim=1)

        if crop_mode2 is not None:
            # Build a second image-plus-mask tensor using the secondary crop mode.
            images_tensor2, image_info2 = process_image(image_pil, self.model.config, None, pil_preprocess_fn=lambda pil_img: self.crop_image(pil_img, mask_np=mask_np, crop_mode=crop_mode2))
            images_tensor2 = images_tensor2[None].to(self.model.device, dtype=torch.float16)

            mask_np2 = image_info2["mask_np"]
            mask_pil2 = Image.fromarray(mask_np2 * 255)

            masks_tensor2 = process_image(mask_pil2, self.model.config, None)
            masks_tensor2 = masks_tensor2[None].to(self.model.device, dtype=torch.float16)

            images_tensor2 = torch.cat((images_tensor2, masks_tensor2[:, :1, ...]), dim=1)
        else:
            images_tensor2 = None

        # When a second crop mode is used, stack the two views along the channel dimension.
        return torch.cat((images_tensor, images_tensor2), dim=1) if images_tensor2 is not None else images_tensor

    def get_description_from_prompt(self, image_pils, mask_pils, prompt, conv, streaming=False):
        if streaming:
            return self.get_description_from_prompt_iterator(image_pils, mask_pils, prompt, conv, streaming=True)
        else:
            # The iterator yields exactly one item in non-streaming mode.
            output = self.get_description_from_prompt_iterator(image_pils, mask_pils, prompt, conv, streaming=False)
            return next(output)

    def get_description_from_prompt_iterator(self, image_pils, mask_pils, prompt, conv, streaming=False):
        crop_mode, crop_mode2 = self.prompt_mode.split("+")
        assert crop_mode == "full", "The current prompt only supports 'full' (non-cropped) as the first crop mode. If you need other settings, please update the prompt."

        assert len(image_pils) == len(mask_pils), f"image_pils and mask_pils must have the same length. Got {len(image_pils)} and {len(mask_pils)}."
        image_tensors = [self.get_image_tensor(image_pil, mask_pil, crop_mode=crop_mode, crop_mode2=crop_mode2) for image_pil, mask_pil in zip(image_pils, mask_pils)]

        input_ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).cuda()

        # Stop generation once the conversation separator is produced.
        stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
        keywords = [stop_str]
        stopping_criteria = KeywordsStoppingCriteria(keywords, self.tokenizer, input_ids)

        streamer = TextIteratorStreamer(self.tokenizer, skip_prompt=True, skip_special_tokens=True) if streaming else None
        generation_kwargs = dict(
            input_ids=input_ids,
            images=image_tensors,
            do_sample=True if self.temperature > 0 else False,
            temperature=self.temperature,
            top_p=self.top_p,
            num_beams=self.num_beams,
            max_new_tokens=self.max_new_tokens,
            use_cache=True,
            stopping_criteria=[stopping_criteria],
            streamer=streamer,
        )

        if streaming:
            # Run generation in a background thread and yield text chunks as they arrive.
            thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
            thread.start()

            generated_text = ""
            for new_text in streamer:
                generated_text += new_text
                if stop_str in generated_text:
                    generated_text = generated_text[:generated_text.find(stop_str)]
                    break
                yield new_text

            thread.join()
        else:
            with torch.inference_mode():
                output_ids = self.model.generate(**generation_kwargs)

            outputs = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
            outputs = outputs.strip()
            # Strip the trailing stop string if present.
            if outputs.endswith(stop_str):
                outputs = outputs[: -len(stop_str)]
            outputs = outputs.strip()

            yield outputs

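
# Minimal usage sketch (not part of the original module): the checkpoint path, conv_mode,
# and sampling settings below are illustrative assumptions, not values confirmed by this file.
# prompt_mode follows the "full+<secondary crop>" convention enforced by
# get_description_from_prompt_iterator; "focal_crop" is one of the crop modes supported above.
# Because this module uses relative imports, it must be executed as part of its package
# (e.g. via `python -m`), not as a standalone script.
if __name__ == "__main__":
    image = Image.open("examples/image.png").convert("RGB")  # hypothetical image path
    mask = Image.open("examples/mask.png").convert("L")      # hypothetical binary region mask

    dam = DescribeAnythingModel(
        model_path="path/to/checkpoint",   # placeholder checkpoint directory
        conv_mode="llava_v1",              # assumed conversation template name
        prompt_mode="full+focal_crop",     # full view plus a focal crop of the masked region
        temperature=0.2,
        top_p=0.9,
        num_beams=1,
        max_new_tokens=256,
    )

    # The query must contain the image token expected by get_prompt.
    query = f"{DEFAULT_IMAGE_TOKEN}\nDescribe the masked region in detail."
    print(dam.get_description(image, mask, query))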