import os
import random
import re
from collections import Counter

import numpy as np
import torch
import webcolors
from torchvision import transforms
from transformers import TextStreamer

from .magic_utils import get_colored_contour, find_different_colors, get_bounding_box_from_mask
from .LLaVA.llava.conversation import conv_templates, SeparatorStyle
from .LLaVA.llava.model.builder import load_pretrained_model
from .LLaVA.llava.mm_utils import get_model_name_from_path, expand2square, tokenizer_image_token
from .LLaVA.llava.constants import (
    IMAGE_TOKEN_INDEX,
    DEFAULT_IMAGE_TOKEN,
    DEFAULT_IM_START_TOKEN,
    DEFAULT_IM_END_TOKEN,
    IMAGE_PLACEHOLDER,
)
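

# Wraps a fine-tuned LLaVA-v1.5-7B checkpoint behind two entry points:
# `generate_description` for free-form visual question answering, and
# `process` for interpreting user sketches and colored contours.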
class LLaVAModel:
    def __init__(self):
        current_dir = os.path.dirname(os.path.abspath(__file__))
        model_path = os.path.join(current_dir, "../models/llava-v1.5-7b-finetune-clean")
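        # Load tokenizer, model, and image processor; 4-bit quantization
        # keeps the 7B checkpoint's GPU memory footprint small.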
        self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model(
            model_path=model_path,
            model_base=None,
            model_name=get_model_name_from_path(model_path),
            load_4bit=True,
        )

    def generate_description(self, images, question):
        qs = question
        image_token_se = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN
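        # Splice the image token into the prompt: replace an explicit
        # placeholder if present, otherwise prepend the token to the question.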
        if IMAGE_PLACEHOLDER in qs:
            if self.model.config.mm_use_im_start_end:
                qs = re.sub(IMAGE_PLACEHOLDER, image_token_se, qs)
            else:
                qs = re.sub(IMAGE_PLACEHOLDER, DEFAULT_IMAGE_TOKEN, qs)
        else:
            if self.model.config.mm_use_im_start_end:
                qs = image_token_se + "\n" + qs
            else:
                qs = DEFAULT_IMAGE_TOKEN + "\n" + qs
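
        # Convert each HWC tensor (values scaled to [0, 255] by the callers)
        # to a PIL image, pad it to a square with the processor's mean color,
        # and run the CLIP-style preprocessing pipeline.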
        images_tensor = []
        image_sizes = []
        to_pil = transforms.ToPILImage()
        for image in images:
            # Cast to uint8 so ToPILImage does not rescale the already
            # 0-255-valued floats by 255 a second time.
            image = image.clone().permute(2, 0, 1).cpu().to(torch.uint8)
            image = to_pil(image)
            image_sizes.append(image.size)
            # image_mean is stored in [0, 1]; scale it to 0-255 for the PIL
            # padding color.
            image = expand2square(image, tuple(int(x * 255) for x in self.image_processor.image_mean))
            image = self.image_processor.preprocess(image, return_tensors="pt")["pixel_values"][0]
            images_tensor.append(image.half())
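
        # Wrap the question in the llava_v1 chat template; the assistant slot
        # is left empty for the model to fill in.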
        conv = conv_templates["llava_v1"].copy()
        conv.append_message(conv.roles[0], qs)
        conv.append_message(conv.roles[1], None)
        prompt = conv.get_prompt()
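
        # tokenizer_image_token substitutes IMAGE_TOKEN_INDEX for the image
        # placeholder so the model can splice in the visual features.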
        input_ids = (
            tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
            .unsqueeze(0)
            .cuda()
        )
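
        # Sample with a low temperature for mostly deterministic answers.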
        with torch.inference_mode():
            output_ids = self.model.generate(
                input_ids,
                images=images_tensor,
                image_sizes=image_sizes,
                temperature=0.2,
                do_sample=True,
                use_cache=True,
            )
        # Decode the generated tokens, dropping special markers such as
        # <s> and </s>.
        outputs = self.tokenizer.decode(output_ids[0], skip_special_tokens=True).strip()
        return outputs
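
    # Inspects the user's edits: if the sketch mask is non-empty, ask the
    # model to guess what is being drawn; if `colored_image` differs from
    # `image`, ask it to name what lies inside the colored contour.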
    def process(self, image, colored_image, add_mask):
        description = ""
        answer1 = ""
        answer2 = ""
        image_with_sketch = image.clone()
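
        # Branch 1: the user has drawn something (non-empty sketch mask).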
        if torch.sum(add_mask).item() > 0:
            x_min, y_min, x_max, y_max = get_bounding_box_from_mask(add_mask)
            question = f"This is an 'I draw, you guess' game. I will upload an image containing some sketches. To help you locate the sketch, I will give you the normalized bounding box coordinates of the sketch where their original coordinates are divided by the image width and height. The top-left corner of the bounding box is at ({x_min}, {y_min}), and the bottom-right corner is at ({x_max}, {y_max}). Now tell me, what am I trying to draw with these sketches in the image?"
            # Render the sketch with contrast against the underlying region:
            # black strokes on bright areas, white strokes on dark ones.
            bool_add_mask = add_mask > 0.5
            mean_brightness = image_with_sketch[bool_add_mask].mean()
            if mean_brightness > 0.8:
                image_with_sketch[bool_add_mask] = 0.0
            else:
                image_with_sketch[bool_add_mask] = 1.0
            answer1 = self.generate_description([image_with_sketch.squeeze() * 255], question)
            print(answer1)
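
        # Branch 2: the user has colored a region; locate the changed pixels
        # and ask the model what sits inside the highlighted contour.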
        if not torch.equal(image, colored_image):
            color = find_different_colors(image.squeeze() * 255, colored_image.squeeze() * 255)
            image_with_bbox, colored_mask = get_colored_contour(colored_image.squeeze() * 255, image.squeeze() * 255)
            x_min, y_min, x_max, y_max = get_bounding_box_from_mask(colored_mask)
            question = f"The user will upload an image containing some contours in red color. To help you locate the contour, I will give you the normalized bounding box coordinates where their original coordinates are divided by the image width and height. The top-left corner of the bounding box is at ({x_min}, {y_min}), and the bottom-right corner is at ({x_max}, {y_max}). You need to identify what is inside the contours using a single word or phrase."
            answer2 = color + ", " + self.generate_description([image_with_bbox.squeeze() * 255], question)
            print(answer2)

        return (description, answer1, answer2)
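

# Usage sketch (illustrative): drives the class above with synthetic inputs.
# The tensor shapes are an assumption inferred from `process`: HWC float
# images in [0, 1] with a batch dimension of 1 and a matching single-channel
# mask. The checkpoint path hard-coded in __init__ must exist locally, and a
# CUDA device is required, for this to run.
if __name__ == "__main__":
    model = LLaVAModel()
    image = torch.rand(1, 512, 512, 3)    # base canvas, float values in [0, 1]
    colored_image = image.clone()         # identical, so the contour branch is skipped
    add_mask = torch.zeros(1, 512, 512)   # sketch mask, 1.0 where the user drew
    add_mask[0, 100:200, 100:200] = 1.0   # pretend a square was sketched
    _, sketch_guess, contour_guess = model.process(image, colored_image, add_mask)
    print(sketch_guess, contour_guess)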