import argparse from transformers import AutoTokenizer, AutoModelForCausalLM import torch import os from llava.conversation import conv_templates, SeparatorStyle from llava.utils import disable_torch_init from transformers import CLIPVisionModel, CLIPImageProcessor, StoppingCriteria from llava.model import * from llava.model.utils import KeywordsStoppingCriteria from PIL import Image import os import requests from PIL import Image from io import BytesIO import glob import numpy as np import json import tqdm DEFAULT_IMAGE_TOKEN = "" DEFAULT_IMAGE_PATCH_TOKEN = "" DEFAULT_IM_START_TOKEN = "" DEFAULT_IM_END_TOKEN = "" def load_image(image_file): if image_file.startswith('http') or image_file.startswith('https'): response = requests.get(image_file) image = Image.open(BytesIO(response.content)).convert('RGB') else: image = Image.open(image_file).convert('RGB') return image classes = ['wall', 'building', 'sky', 'floor', 'tree', 'ceiling', 'road', 'bed', 'windowpane', 'grass', 'cabinet', 'sidewalk', 'person', 'earth', 'door', 'table', 'mountain', 'plant', 'curtain', 'chair', 'car', 'water', 'painting', 'sofa', 'shelf', 'house', 'sea', 'mirror', 'rug', 'field', 'armchair', 'seat', 'fence', 'desk', 'rock', 'wardrobe', 'lamp', 'bathtub', 'railing', 'cushion', 'base', 'box', 'column', 'signboard', 'chest of drawers', 'counter', 'sand', 'sink', 'skyscraper', 'fireplace', 'refrigerator', 'grandstand', 'path', 'stairs', 'runway', 'case', 'pool table', 'pillow', 'screen door', 'stairway', 'river', 'bridge', 'bookcase', 'blind', 'coffee table', 'toilet', 'flower', 'book', 'hill', 'bench', 'countertop', 'stove', 'palm', 'kitchen island', 'computer', 'swivel chair', 'boat', 'bar', 'arcade machine', 'hovel', 'bus', 'towel', 'light', 'truck', 'tower', 'chandelier', 'awning', 'streetlight', 'booth', 'television receiver', 'airplane', 'dirt track', 'apparel', 'pole', 'land', 'bannister', 'escalator', 'ottoman', 'bottle', 'buffet', 'poster', 'stage', 'van', 'ship', 'fountain', 'conveyer belt', 'canopy', 'washer', 'plaything', 'swimming pool', 'stool', 'barrel', 'basket', 'waterfall', 'tent', 'bag', 'minibike', 'cradle', 'oven', 'ball', 'food', 'step', 'tank', 'trade name', 'microwave', 'pot', 'animal', 'bicycle', 'lake', 'dishwasher', 'screen', 'blanket', 'sculpture', 'hood', 'sconce', 'vase', 'traffic light', 'tray', 'ashcan', 'fan', 'pier', 'crt screen', 'plate', 'monitor', 'bulletin board', 'shower', 'radiator', 'glass', 'clock', 'flag'] def eval_model(args): # Model disable_torch_init() model_name = os.path.expanduser(args.model_name) tokenizer = AutoTokenizer.from_pretrained(model_name) if "mpt" in model_name.lower(): model = LlavaMPTForCausalLM.from_pretrained(model_name, low_cpu_mem_usage=True, torch_dtype=torch.float16, use_cache=True).cuda() else: # model = LlavaLlamaForCausalLM.from_pretrained(model_name, low_cpu_mem_usage=True, torch_dtype=torch.float16, use_cache=True).cuda() model = LlavaLlamaForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map='auto')#.cuda() image_processor = CLIPImageProcessor.from_pretrained(model.config.mm_vision_tower, torch_dtype=torch.float16) mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False) tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True) if mm_use_im_start_end: tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True) vision_tower = model.get_model().vision_tower[0] if vision_tower.device.type == 'meta': vision_tower = CLIPVisionModel.from_pretrained(vision_tower.config._name_or_path, torch_dtype=torch.float16, low_cpu_mem_usage=True).cuda() model.get_model().vision_tower[0] = vision_tower else: vision_tower.to(device='cuda', dtype=torch.float16) vision_config = vision_tower.config vision_config.im_patch_token = tokenizer.convert_tokens_to_ids([DEFAULT_IMAGE_PATCH_TOKEN])[0] vision_config.use_im_start_end = mm_use_im_start_end if mm_use_im_start_end: vision_config.im_start_token, vision_config.im_end_token = tokenizer.convert_tokens_to_ids([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN]) image_token_len = (vision_config.image_size // vision_config.patch_size) ** 2 # paths for all images images = sorted(glob.glob("/mnt/proj74/xinlai/dataset/ade20k/images/training/*.jpg")) results = [] for i, image_file in enumerate(tqdm.tqdm(images)): # if i == 2: # break # if i % 100 == 0: # print("i: {}, len(images): {}".format(i, len(images))) print("i: {}, len(images): {}".format(i, len(images))) image = load_image(image_file) image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0] image_tensor = image_tensor.unsqueeze(0).half().cuda() label_file = image_file.replace("images", "annotations").replace(".jpg", ".png") label = Image.open(label_file) label = np.array(label) label_unique = np.unique(label) for label in label_unique: if label == 0: continue class_id = label - 1 class_label = classes[class_id] input_conv = "Can you describe the {} in this image?".format(class_label) qs = input_conv # qs = args.query if mm_use_im_start_end: qs = qs + '\n' + DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_PATCH_TOKEN * image_token_len + DEFAULT_IM_END_TOKEN else: qs = qs + '\n' + DEFAULT_IMAGE_PATCH_TOKEN * image_token_len if "v1" in model_name.lower(): conv_mode = "llava_v1" elif "mpt" in model_name.lower(): conv_mode = "mpt_multimodal" else: conv_mode = "multimodal" if args.conv_mode is not None and conv_mode != args.conv_mode: print('[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}'.format(conv_mode, args.conv_mode, args.conv_mode)) else: args.conv_mode = conv_mode conv = conv_templates[args.conv_mode].copy() conv.append_message(conv.roles[0], qs) conv.append_message(conv.roles[1], None) prompt = conv.get_prompt() inputs = tokenizer([prompt]) # image = load_image(args.image_file) # image = load_image(image_file) # image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0] input_ids = torch.as_tensor(inputs.input_ids).cuda() stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2 keywords = [stop_str] stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids) with torch.inference_mode(): output_ids = model.generate( input_ids, images=image_tensor, do_sample=True, temperature=0.2, max_new_tokens=512, #1024, stopping_criteria=[stopping_criteria]) input_token_len = input_ids.shape[1] n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item() if n_diff_input_output > 0: print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids') outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0] outputs = outputs.strip() if outputs.endswith(stop_str): outputs = outputs[:-len(stop_str)] outputs = outputs.strip() print("qs: {}, output: {}, image_file: {}".format(qs, outputs, image_file)) results.append({'image_id': image_file.split("/")[-1], 'input': input_conv, 'output': outputs}) with open("/mnt/proj74/xinlai/LLM/LLaVA/ade20k_conversations.json", "w+") as f: json.dump(results, f) # print(outputs) if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--model-name", type=str, default="facebook/opt-350m") parser.add_argument("--image-file", type=str, required=True) parser.add_argument("--query", type=str, required=True) parser.add_argument("--conv-mode", type=str, default=None) args = parser.parse_args() eval_model(args)