| import argparse | |
| import json | |
| from math import ceil | |
| import os | |
| import random | |
| import uuid | |
| from collections import defaultdict, Counter | |
| from typing import Callable | |
| import time | |
| import cv2 | |
| import webdataset as wds | |
| from sklearn.metrics import recall_score, average_precision_score | |
| import more_itertools | |
| import numpy as np | |
| import torch | |
| from coco_metric import compute_cider, postprocess_captioning_generation | |
| from eval_datasets import VQADataset, GQADataset | |
| from tqdm import tqdm | |
| from vqa_metric import compute_vqa_accuracy, compute_gqa_accuracy | |
| from open_flamingo.eval.classification import ( | |
| compute_per_sample_probs, | |
| compute_per_sample_loss, | |
| ) | |
| from open_flamingo.eval.imagenet_utils import ( | |
| openai_imagenet_classnames, | |
| IMAGENET_1K_CLASS_ID_TO_LABEL, | |
| ) | |
| from open_flamingo.src.factory import create_model_and_transforms | |
| from PIL import Image | |
| from io import BytesIO | |
| import base64 | |
| from open_flamingo.train.distributed import init_distributed_device, world_info_from_env | |
| import string | |
| import logging | |
| from lavis.datasets.builders import load_dataset | |
| def get_iou(box1, box2): | |
| # box1 and box2 should be in the format [x1, y1, x2, y2] | |
| intersection = max(0, min(box1[2], box2[2]) - max(box1[0], box2[0])) * \ | |
| max(0, min(box1[3], box2[3]) - max(box1[1], box2[1])) | |
| area_box1 = (box1[2] - box1[0]) * (box1[3] - box1[1]) | |
| area_box2 = (box2[2] - box2[0]) * (box2[3] - box2[1]) | |
| union = area_box1 + area_box2 - intersection | |
| iou = intersection / union if union > 0 else 0 | |
| return iou | |
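| # quick sanity check (values follow directly from the formula above): | |
| # get_iou([0, 0, 2, 2], [1, 1, 3, 3]) -> 1 / (4 + 4 - 1) ≈ 0.143 | |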
| def expand2square(pil_img, background_color): | |
| width, height = pil_img.size | |
| if width == height: | |
| return pil_img | |
| elif width > height: | |
| result = Image.new(pil_img.mode, (width, width), background_color) | |
| result.paste(pil_img, (0, (width - height) // 2)) | |
| return result | |
| else: | |
| result = Image.new(pil_img.mode, (height, height), background_color) | |
| result.paste(pil_img, ((height - width) // 2, 0)) | |
| return result | |
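| # e.g. a 640x480 RGB image becomes 640x640 with the original pasted at (0, 80) on a | |
| # canvas filled with background_color; the aspect ratio is kept by padding, not cropping. | |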
| parser = argparse.ArgumentParser() | |
| parser.add_argument("--lm_path", type=str, default="facebook/opt-1.3b") | |
| parser.add_argument("--lm_tokenizer_path", type=str, default="facebook/opt-30b") | |
| parser.add_argument("--vision_encoder_path", default="ViT-L-14", type=str) | |
| parser.add_argument("--vision_encoder_pretrained", default="openai", type=str) | |
| parser.add_argument("--checkpoint_path", type=str, required=True) | |
| parser.add_argument( | |
| "--results_file", type=str, default=None, help="JSON file to save results" | |
| ) | |
| # Trial arguments | |
| parser.add_argument("--shots", nargs="+", default=[0, 4, 8, 16, 32], type=int) | |
| parser.add_argument( | |
| "--num_trials", | |
| type=int, | |
| default=1, | |
| help="Number of trials to run for each shot using different demonstrations", | |
| ) | |
| parser.add_argument( | |
| "--trial_seeds", | |
| nargs="+", | |
| default=[0], | |
| help="Seeds to use for each trial for picking demonstrations and eval sets", | |
| ) | |
| parser.add_argument( | |
| "--num_samples", type=int, default=5000, help="Number of samples to evaluate on" | |
| ) | |
| parser.add_argument("--batch_size", type=int, default=8) | |
| # Per-dataset evaluation flags | |
| parser.add_argument( | |
| "--eval_coco", | |
| action="store_true", | |
| default=False, | |
| help="Whether to evaluate on COCO.", | |
| ) | |
| parser.add_argument( | |
| "--eval_vqav2", | |
| action="store_true", | |
| default=False, | |
| help="Whether to evaluate on VQAV2.", | |
| ) | |
| parser.add_argument( | |
| "--eval_ok_vqa", | |
| action="store_true", | |
| default=False, | |
| help="Whether to evaluate on OK-VQA.", | |
| ) | |
| parser.add_argument( | |
| "--eval_imagenet", | |
| action="store_true", | |
| default=False, | |
| help="Whether to evaluate on ImageNet.", | |
| ) | |
| parser.add_argument( | |
| "--eval_flickr30", | |
| action="store_true", | |
| default=False, | |
| help="Whether to evaluate on Flickr30.", | |
| ) | |
| parser.add_argument( | |
| "--eval_refcoco", | |
| action="store_true", | |
| default=False, | |
| help="Whether to evaluate on RefCOCO.", | |
| ) | |
| # Dataset arguments | |
| ## Flickr30 Dataset | |
| parser.add_argument( | |
| "--flickr_image_dir_path", | |
| type=str, | |
| help="Path to the flickr30/flickr30k_images directory.", | |
| default=None, | |
| ) | |
| parser.add_argument( | |
| "--flickr_annotations_json_path", | |
| type=str, | |
| help="Path to the dataset_flickr30k_coco_style.json file.", | |
| default=None, | |
| ) | |
| ## COCO Dataset | |
| parser.add_argument( | |
| "--coco_image_dir_path", | |
| type=str, | |
| help="Path to the flickr30/flickr30k_images directory.", | |
| default=None, | |
| ) | |
| parser.add_argument( | |
| "--coco_annotations_json_path", | |
| type=str, | |
| default=None, | |
| ) | |
| ## VQAV2 Dataset | |
| parser.add_argument( | |
| "--vqav2_image_dir_path", | |
| type=str, | |
| default=None, | |
| ) | |
| parser.add_argument( | |
| "--vqav2_questions_json_path", | |
| type=str, | |
| default=None, | |
| ) | |
| parser.add_argument( | |
| "--vqav2_annotations_json_path", | |
| type=str, | |
| default=None, | |
| ) | |
| ## OK-VQA Dataset | |
| parser.add_argument( | |
| "--ok_vqa_image_dir_path", | |
| type=str, | |
| help="Path to the vqav2/train2014 directory.", | |
| default=None, | |
| ) | |
| parser.add_argument( | |
| "--ok_vqa_questions_json_path", | |
| type=str, | |
| help="Path to the v2_OpenEnded_mscoco_train2014_questions.json file.", | |
| default=None, | |
| ) | |
| parser.add_argument( | |
| "--ok_vqa_annotations_json_path", | |
| type=str, | |
| help="Path to the v2_mscoco_train2014_annotations.json file.", | |
| default=None, | |
| ) | |
| ## Imagenet dataset | |
| parser.add_argument("--imagenet_root", type=str, default="/tmp") | |
| ## RefCOCO dataset | |
| parser.add_argument("--refcoco_tsvfile", type=str, default=None) | |
| parser.add_argument( | |
| "--location_token_num", | |
| default=1000, | |
| type=int, | |
| ) | |
| # distributed training | |
| parser.add_argument( | |
| "--dist-url", | |
| default="env://", | |
| type=str, | |
| help="url used to set up distributed training", | |
| ) | |
| parser.add_argument( | |
| "--dist-backend", default="nccl", type=str, help="distributed backend" | |
| ) | |
| parser.add_argument( | |
| "--horovod", | |
| default=False, | |
| action="store_true", | |
| help="Use horovod for distributed training.", | |
| ) | |
| parser.add_argument( | |
| "--no-set-device-rank", | |
| default=False, | |
| action="store_true", | |
| help="Don't set device index from local rank (when CUDA_VISIBLE_DEVICES restricted to one per proc).", | |
| ) | |
| parser.add_argument( | |
| "--dist", | |
| default=False, | |
| action="store_true", | |
| ) | |
| parser.add_argument( | |
| "--lora", | |
| default=False, | |
| action="store_true", | |
| ) | |
| parser.add_argument( | |
| "--lora_r", | |
| default=16, | |
| type=int, | |
| required=False, | |
| ) | |
| parser.add_argument( | |
| "--legacy", | |
| default=False, | |
| action="store_true", | |
| ) | |
| parser.add_argument( | |
| "--special", | |
| default=False, | |
| action="store_true", | |
| ) | |
| parser.add_argument( | |
| "--id", | |
| default=0, | |
| type=int, | |
| required=False, | |
| ) | |
| parser.add_argument( | |
| "--eval_gqa", | |
| default=False, | |
| action="store_true", | |
| ) | |
| parser.add_argument( | |
| "--use_sam", | |
| default=None, | |
| type=str, | |
| required=False, | |
| ) | |
| parser.add_argument( | |
| "--add_visual_token", | |
| default=False, | |
| action="store_true", | |
| ) | |
| parser.add_argument( | |
| "--use_format_v2", | |
| default=False, | |
| action="store_true", | |
| ) | |
| parser.add_argument( | |
| "--eval_aro", | |
| default=False, | |
| action="store_true", | |
| ) | |
| parser.add_argument( | |
| "--eval_pisc", | |
| default=False, | |
| action="store_true", | |
| ) | |
| class OKVQAPostProcess(): | |
| def __init__(self): | |
| self._lemmatizer = None | |
| def _lemmatize(self, answers): | |
| def apply(answer): | |
| doc = self.lemmatizer(answer) | |
| words = [] | |
| for token in doc: | |
| if token.pos_ in ["NOUN", "VERB"]: | |
| words.append(token.lemma_) | |
| else: | |
| words.append(token.text) | |
| answer = " ".join(words) | |
| return answer | |
| return [apply(answer) for answer in answers] | |
| @property | |
| def lemmatizer(self): | |
| if self._lemmatizer is None: | |
| try: | |
| import spacy | |
| self._lemmatizer = spacy.load("en_core_web_sm") | |
| except ImportError: | |
| logging.error( | |
| """ | |
| Please install spacy and en_core_web_sm model to apply lemmatization. | |
| python -m spacy download en_core_web_sm | |
| OR | |
| import spacy.cli | |
| spacy.cli.download("en_core_web_sm") | |
| """ | |
| ) | |
| exit(1) | |
| return self._lemmatizer | |
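| # usage sketch (assumes spacy and the en_core_web_sm model are installed): | |
| # OKVQAPostProcess()._lemmatize(["riding bikes"]) would typically return ["ride bike"], | |
| # since only NOUN/VERB tokens are replaced by their lemmas. | |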
| def main(): | |
| args = parser.parse_args() | |
| if args.dist: | |
| args.local_rank, args.rank, args.world_size = world_info_from_env() | |
| print(f"local_rank: {args.local_rank} rank: {args.rank} world_size: {args.world_size}") | |
| device_id = init_distributed_device(args) | |
| else: | |
| args.rank = 0 | |
| args.world_size = 1 | |
| print(f"rank: {args.rank} world_size: {args.world_size}") | |
| if "sam" in args.checkpoint_path: | |
| args.use_sam = "vit_l" | |
| args.add_visual_token = True | |
| if "lora" in args.checkpoint_path: | |
| args.lora = True | |
| args.add_pe = False | |
| args.add_box = True | |
| args.relation = False | |
| args.enhance_data = False | |
| args.use_format_v2 = True | |
| import hashlib | |
| args.id = hashlib.sha224(args.checkpoint_path.encode()).hexdigest() | |
| # load model | |
| flamingo, image_processor, tokenizer, vis_embed_size = create_model_and_transforms( | |
| args.vision_encoder_path, | |
| args.vision_encoder_pretrained, | |
| args.lm_path, | |
| args.lm_tokenizer_path, | |
| location_token_num=args.location_token_num, | |
| lora=args.lora, | |
| lora_r=args.lora_r, | |
| use_sam=args.use_sam, | |
| add_visual_token=args.add_visual_token, | |
| use_format_v2=args.use_format_v2, | |
| add_box=args.add_box, | |
| add_pe=args.add_pe, | |
| add_relation=args.relation, | |
| enhance_data=args.enhance_data, | |
| ) | |
| flamingo.use_format_v2 = args.use_format_v2 | |
| if args.special: | |
| flamingo.special = True | |
| else: | |
| flamingo.special = False | |
| if args.legacy: | |
| flamingo.legacy = True | |
| print("use legacy evaluation") | |
| flamingo.step_num = int(args.checkpoint_path.split("/")[-1].split(".")[0].split("_")[-1]) | |
| flamingo.expr_name = args.checkpoint_path.split("/")[-2] | |
| if args.rank == 0: | |
| print("legacy", True if hasattr(flamingo, "legacy") else False) | |
| print("step:", flamingo.step_num) | |
| print("expr:", flamingo.expr_name) | |
| print("use format v2:", flamingo.use_format_v2) | |
| print(args) | |
| checkpoint = torch.load(args.checkpoint_path, map_location="cpu") | |
| model_state_dict = {} | |
| for key in checkpoint["model_state_dict"].keys(): | |
| model_state_dict[key.replace("module.", "")] = checkpoint["model_state_dict"][key] | |
| if "vision_encoder.logit_scale"in model_state_dict: | |
| # previous checkpoint has some unnecessary weights | |
| del model_state_dict["vision_encoder.logit_scale"] | |
| del model_state_dict["vision_encoder.visual.proj"] | |
| del model_state_dict["vision_encoder.visual.ln_post.weight"] | |
| del model_state_dict["vision_encoder.visual.ln_post.bias"] | |
| flamingo.load_state_dict(model_state_dict, strict=True) | |
| results = defaultdict(list) | |
| if args.eval_coco: | |
| print("Evaluating on COCO...") | |
| for shot in args.shots: | |
| scores = [] | |
| for seed, trial in zip(args.trial_seeds, range(args.num_trials)): | |
| cider_score = evaluate_coco_flickr( | |
| model=flamingo, | |
| tokenizer=tokenizer, | |
| image_processor=image_processor, | |
| batch_size=args.batch_size, | |
| image_dir_path=args.coco_image_dir_path, | |
| annotations_json_path=args.coco_annotations_json_path, | |
| device=args.device, | |
| seed=seed, | |
| vis_embed_size=vis_embed_size, | |
| rank=args.rank, | |
| world_size=args.world_size, | |
| id=args.id, | |
| ) | |
| print(f"Shots {shot} Trial {trial} CIDEr score: {cider_score}") | |
| scores.append(cider_score) | |
| print(f"Shots {shot} Mean CIDEr score: {np.mean(scores)}") | |
| results["coco"].append( | |
| {"shots": shot, "trials": scores, "mean": np.mean(scores)} | |
| ) | |
| if args.eval_ok_vqa: | |
| print("Evaluating on OK-VQA...") | |
| for shot in args.shots: | |
| scores = [] | |
| for seed, trial in zip(args.trial_seeds, range(args.num_trials)): | |
| ok_vqa_score = evaluate_vqa( | |
| model=flamingo, | |
| tokenizer=tokenizer, | |
| image_processor=image_processor, | |
| batch_size=args.batch_size, | |
| image_dir_path=args.ok_vqa_image_dir_path, | |
| questions_json_path=args.ok_vqa_questions_json_path, | |
| annotations_json_path=args.ok_vqa_annotations_json_path, | |
| vqa_dataset="ok_vqa", | |
| vis_embed_size=vis_embed_size, | |
| rank=args.rank, | |
| world_size=args.world_size, | |
| id=args.id, | |
| ) | |
| results["ok_vqa"].append( | |
| {"shots": shot, "score": ok_vqa_score} | |
| ) | |
| if args.eval_vqav2: | |
| print("Evaluating on VQAv2...") | |
| for shot in args.shots: | |
| scores = [] | |
| for seed, trial in zip(args.trial_seeds, range(args.num_trials)): | |
| vqa_score = evaluate_vqa( | |
| model=flamingo, | |
| tokenizer=tokenizer, | |
| image_processor=image_processor, | |
| batch_size=args.batch_size, | |
| image_dir_path=args.vqav2_image_dir_path, | |
| questions_json_path=args.vqav2_questions_json_path, | |
| annotations_json_path=args.vqav2_annotations_json_path, | |
| vqa_dataset="vqa", | |
| vis_embed_size=vis_embed_size, | |
| rank=args.rank, | |
| world_size=args.world_size, | |
| id=args.id, | |
| ) | |
| results["vqav2"].append( | |
| {"shots": shot, "score": vqa_score} | |
| ) | |
| if args.eval_gqa: | |
| print("Evaluating on GQA...") | |
| for shot in args.shots: | |
| scores = [] | |
| for seed, trial in zip(args.trial_seeds, range(args.num_trials)): | |
| vqa_score = evaluate_vqa( | |
| model=flamingo, | |
| tokenizer=tokenizer, | |
| image_processor=image_processor, | |
| batch_size=args.batch_size, | |
| vqa_dataset="gqa", | |
| vis_embed_size=vis_embed_size, | |
| rank=args.rank, | |
| world_size=args.world_size, | |
| id=args.id, | |
| ) | |
| results["gqa"].append( | |
| {"shots": shot, "score": vqa_score} | |
| ) | |
| if args.eval_imagenet: | |
| print("Evaluating on ImageNet...") | |
| for shot in args.shots: | |
| scores = [] | |
| for seed, trial in zip(args.trial_seeds, range(args.num_trials)): | |
| imagenet_score = evaluate_imagenet( | |
| model=flamingo, | |
| tokenizer=tokenizer, | |
| image_processor=image_processor, | |
| batch_size=args.batch_size, | |
| num_samples=args.num_samples, | |
| num_shots=shot, | |
| device=args.device, | |
| seed=seed, | |
| imagenet_root=args.imagenet_root, | |
| ) | |
| print( | |
| f"Shots {shot} Trial {trial} " f"ImageNet score: {imagenet_score}" | |
| ) | |
| scores.append(imagenet_score) | |
| print(f"Shots {shot} Mean ImageNet score: {np.mean(scores)}") | |
| results["imagenet"].append( | |
| {"shots": shot, "trials": scores, "mean": np.mean(scores)} | |
| ) | |
| if args.eval_refcoco: | |
| print("Evaluating on RefCOCO...") | |
| refcoco_score = evaluate_refcoco( | |
| model=flamingo, | |
| tokenizer=tokenizer, | |
| image_processor=image_processor, | |
| batch_size=args.batch_size, | |
| device=args.device, | |
| tsvfile=args.refcoco_tsvfile, | |
| vis_embed_size=vis_embed_size, | |
| rank=args.rank, | |
| world_size=args.world_size, | |
| id=args.id, | |
| ) | |
| results["refcoco"].append( | |
| {"score": refcoco_score} | |
| ) | |
| if args.eval_aro: | |
| print("Evaluating on ARO...") | |
| _func = evaluate_aro | |
| # print("Evaluating on ARO ORI...") | |
| # _func = evaluate_aro_ori | |
| aro_score = _func( | |
| model=flamingo, | |
| tokenizer=tokenizer, | |
| image_processor=image_processor, | |
| batch_size=args.batch_size, | |
| device=args.device, | |
| tsvfile=args.refcoco_tsvfile, | |
| vis_embed_size=vis_embed_size, | |
| rank=args.rank, | |
| world_size=args.world_size, | |
| id=args.id, | |
| add_relation=args.relation, | |
| ) | |
| results["aro"].append( | |
| {"score": aro_score} | |
| ) | |
| if args.eval_pisc: | |
| print("Evaluating on ARO...") | |
| aro_score = evaluate_pisc( | |
| model=flamingo, | |
| tokenizer=tokenizer, | |
| image_processor=image_processor, | |
| batch_size=args.batch_size, | |
| device=args.device, | |
| tsvfile=args.refcoco_tsvfile, | |
| vis_embed_size=vis_embed_size, | |
| rank=args.rank, | |
| world_size=args.world_size, | |
| id=args.id, | |
| ) | |
| results["pisc"].append( | |
| {"score": aro_score} | |
| ) | |
| def prepare_batch_images(batch, image_processor): | |
| batch_images = None | |
| for b in batch: | |
| b_image = image_processor(b["image"]).unsqueeze(0).unsqueeze(1).unsqueeze(0) | |
| if batch_images is None: | |
| batch_images = b_image | |
| else: | |
| batch_images = torch.cat([batch_images, b_image], dim=0) | |
| return batch_images | |
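| # assuming image_processor returns a (C, H, W) tensor, batch_images has shape | |
| # (batch_size, 1, 1, C, H, W): one image and one frame per sample. | |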
| def get_outputs( | |
| model, | |
| batch_images, | |
| attention_mask, | |
| max_generation_length, | |
| min_generation_length, | |
| num_beams, | |
| length_penalty, | |
| input_ids, | |
| image_start_index_list=None, | |
| image_nums=None, | |
| bad_words_ids=None, | |
| ): | |
| # use both context managers; "and" would enter only one of them | |
| with torch.inference_mode(), torch.cuda.amp.autocast(dtype=torch.float16): | |
| outputs = model.generate( | |
| batch_images, | |
| input_ids, | |
| attention_mask=attention_mask, | |
| max_new_tokens=max_generation_length, | |
| min_length=min_generation_length, | |
| num_beams=num_beams, | |
| length_penalty=length_penalty, | |
| image_start_index_list=image_start_index_list, | |
| image_nums=image_nums, | |
| bad_words_ids=bad_words_ids, | |
| ) | |
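| # the generated sequence still contains the prompt tokens; keep only the continuation | |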
| outputs = outputs[:, len(input_ids[0]) :] | |
| return outputs | |
| def evaluate_coco_flickr( | |
| model, | |
| tokenizer, | |
| image_processor, | |
| batch_size, | |
| image_dir_path, | |
| annotations_json_path, | |
| seed=42, | |
| max_generation_length=20, | |
| num_beams=1, | |
| length_penalty=-2.0, | |
| device=-1, | |
| is_flickr=False, | |
| vis_embed_size=None, | |
| rank=0, | |
| world_size=1, | |
| id=0, | |
| ): | |
| """Evaluate a model on COCO dataset. | |
| Args: | |
| model (nn.Module): model to evaluate | |
| tokenizer (transformers.PreTrainedTokenizer): tokenizer for the model | |
| image_processor : image processor for the model | |
| batch_size (int): batch size | |
| image_dir_path (str, optional): path to the directory containing the images. | |
| annotations_json_path (str, optional): path to the json file containing the annotations. | |
| seed (int, optional): seed for random number generator. Defaults to 42. | |
| max_generation_length (int, optional): maximum length of the generated caption. Defaults to 20. | |
| num_beams (int, optional): number of beams to use for beam search. Defaults to 1. | |
| length_penalty (float, optional): length penalty for beam search. Defaults to -2.0. | |
| device (int, optional): device to use. Defaults to -1. | |
| is_flickr (bool): whether the data is Flickr30k rather than COCO. Defaults to False (COCO). | |
| vis_embed_size (int, optional): number of visual embedding tokens per image. | |
| rank (int, optional): rank of the current process. Defaults to 0. | |
| world_size (int, optional): number of processes. Defaults to 1. | |
| Returns: | |
| float: CIDEr score | |
| """ | |
| # eval_dataset = COCOFlickrDataset( | |
| # image_dir_path=image_dir_path, | |
| # annotations_path=annotations_json_path, | |
| # is_flickr=is_flickr, | |
| # ) | |
| coco_dataset = load_dataset("coco_caption") | |
| eval_dataset = coco_dataset["test"] | |
| model.eval().cuda() | |
| predictions = defaultdict() | |
| lang_encoder_name = model.lang_encoder.__class__.__name__.lower() | |
| # if "peft" in lang_encoder_name: | |
| # lang_encoder_name = model.lang_encoder.base_model.model.__class__.__name__.lower() | |
| try: | |
| media_token_id = tokenizer("<|#image#|>", add_special_tokens=False)["input_ids"][-1] | |
| endofmedia_token_id = tokenizer("<|#endofimage#|>", add_special_tokens=False)["input_ids"][-1] | |
| pad_token_id = tokenizer(tokenizer.pad_token, add_special_tokens=False)["input_ids"][-1] | |
| bos_token_id = tokenizer(tokenizer.bos_token, add_special_tokens=False)["input_ids"][-1] | |
| except: | |
| pass | |
| def get_prompt(sample): | |
| return f"<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>" | |
| tokenizer.padding_side = "left" | |
| cnt = 0 | |
| if world_size > 1: | |
| torch.distributed.barrier() | |
| desc = "Running inference Flickr30" if is_flickr else "Running inference COCO" | |
| for ii, batch in enumerate(more_itertools.chunked( | |
| tqdm(eval_dataset, desc=desc, disable=(rank != 0)), batch_size | |
| )): | |
| if ii % world_size != rank: | |
| continue | |
| cnt += len(batch) | |
| batch_images = prepare_batch_images( | |
| batch=batch, | |
| image_processor=image_processor, | |
| ).cuda() | |
| batch_text = [get_prompt(s) for s in batch] | |
| encodings = tokenizer( | |
| batch_text, | |
| padding="longest", | |
| truncation=True, | |
| return_tensors="pt", | |
| max_length=2000, | |
| ) | |
| input_ids = encodings["input_ids"].cuda() | |
| attention_mask = encodings["attention_mask"].cuda() | |
| skip_special_tokens = False | |
| if hasattr(model, "legacy") and model.legacy and "opt" in lang_encoder_name: | |
| if rank == 0: | |
| tqdm.write("use legacy model") | |
| skip_special_tokens = True | |
| for i in range(len(input_ids)): | |
| media_token_index = (input_ids[i] == media_token_id).nonzero()[0,0] | |
| endofmedia_token_index = (input_ids[i] == endofmedia_token_id).nonzero()[0,0] | |
| input_ids[i, media_token_index - 1] = media_token_id | |
| input_ids[i, media_token_index] = pad_token_id | |
| input_ids[i, endofmedia_token_index - 1] = endofmedia_token_id | |
| input_ids[i, endofmedia_token_index] = bos_token_id | |
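| # the visual embeddings of each sample start right after its <|#image#|> token; | |
| # these offsets and image_nums are passed to this repo's custom generate() | |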
| image_start_index_list = ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist() | |
| image_start_index_list = [[x] for x in image_start_index_list] | |
| image_nums = [1] * len(input_ids) | |
| if "llama" in lang_encoder_name: | |
| attention_mask[input_ids == 0] = 0 | |
| outputs = get_outputs( | |
| model=model, | |
| batch_images=batch_images, | |
| attention_mask=attention_mask, | |
| max_generation_length=30, | |
| min_generation_length=8, | |
| num_beams=5, | |
| length_penalty=0, | |
| input_ids=input_ids, | |
| image_start_index_list=image_start_index_list, | |
| image_nums=image_nums, | |
| ) | |
| new_predictions = [ | |
| postprocess_captioning_generation(out).replace('"', "") | |
| for out in tokenizer.batch_decode(outputs, skip_special_tokens=True) | |
| ] | |
| # if rank == 0: | |
| # tqdm.write(f"{batch_images.shape} {batch[0]} pred: {new_predictions[0]}") | |
| for i, sample in enumerate(batch): | |
| predictions[int(sample["image_id"])] = { | |
| "caption": new_predictions[i], | |
| } | |
| results_path = ( | |
| f"flickrresults_{lang_encoder_name}_{rank}_{id}.json" | |
| if is_flickr | |
| else f"cocoresults_{lang_encoder_name}_{rank}_{id}.json" | |
| ) | |
| with open(results_path, "w") as f: | |
| f.write( | |
| json.dumps( | |
| [ | |
| {"image_id": k, "caption": predictions[k]["caption"]} | |
| for k in predictions | |
| ], | |
| indent=2, | |
| ) | |
| ) | |
| print("save to", results_path) | |
| del predictions | |
| time.sleep(10) | |
| if world_size > 1: | |
| torch.distributed.barrier() | |
| if rank == 0: | |
| print(f"evaluate on rank {rank}. world size is {world_size}") | |
| predictions = [] | |
| for rank_i in range(world_size): | |
| part_results_path = ( | |
| f"flickrresults_{lang_encoder_name}_{rank_i}_{id}.json" | |
| if is_flickr | |
| else f"cocoresults_{lang_encoder_name}_{rank_i}_{id}.json" | |
| ) | |
| print("load", part_results_path) | |
| predictions.extend(json.load(open(part_results_path))) | |
| os.remove(part_results_path) | |
| print("num:", len(predictions)) | |
| results_path = ( | |
| f"flickrresults_{lang_encoder_name}.json" | |
| if is_flickr | |
| else f"cocoresults_{lang_encoder_name}.json" | |
| ) | |
| json.dump(predictions, open(results_path, "w"), indent=2) | |
| metrics = compute_cider( | |
| result_path=results_path, | |
| annotations_path="/gpfs/u/home/LMCG/LMCGljnn/scratch/.cache/lavis/coco_gt/coco_karpathy_test_gt.json", | |
| ) | |
| os.makedirs("eval_results", exist_ok=True) | |
| acc = metrics["CIDEr"] | |
| with open(os.path.join("eval_results", f"cococap_{model.expr_name}_{model.step_num}_{int(time.time())}_{acc}"), "w") as f: | |
| f.write(json.dumps(predictions, indent=2)) | |
| # delete the temporary file | |
| os.remove(results_path) | |
| else: | |
| metrics = {} | |
| metrics["CIDEr"] = 0.0 | |
| return metrics["CIDEr"] | |
| def evaluate_vqa( | |
| model, | |
| tokenizer, | |
| image_processor, | |
| batch_size, | |
| image_dir_path=None, | |
| questions_json_path=None, | |
| annotations_json_path=None, | |
| vqa_dataset="vqa", | |
| vis_embed_size=None, | |
| rank=0, | |
| world_size=1, | |
| id=0, | |
| ): | |
| """ | |
| Evaluate a model on VQA datasets. Currently supports VQA v2.0, OK-VQA, and GQA. | |
| Args: | |
| model (nn.Module): model to evaluate | |
| tokenizer (transformers.PreTrainedTokenizer): tokenizer for the model | |
| image_processor : image processor for the model | |
| batch_size (int): batch size | |
| image_dir_path (str): path to image directory | |
| questions_json_path (str): path to questions json file | |
| annotations_json_path (str): path to annotations json file | |
| vqa_dataset (string): type of vqa dataset: currently supports vqa, ok_vqa, gqa. Defaults to vqa. | |
| vis_embed_size (int, optional): number of visual embedding tokens per image. | |
| rank (int, optional): rank of the current process. Defaults to 0. | |
| world_size (int, optional): number of processes. Defaults to 1. | |
| Returns: | |
| float: accuracy score | |
| """ | |
| if world_size > 1: | |
| torch.distributed.barrier() | |
| if vqa_dataset == "gqa": | |
| eval_dataset = GQADataset() | |
| else: | |
| eval_dataset = VQADataset( | |
| image_dir_path=image_dir_path, | |
| question_path=questions_json_path, | |
| annotations_path=annotations_json_path, | |
| vqa_dataset=vqa_dataset, | |
| ) | |
| postprocessor = OKVQAPostProcess() | |
| try: | |
| media_token_id = tokenizer("<|#image#|>", add_special_tokens=False)["input_ids"][-1] | |
| endofmedia_token_id = tokenizer("<|#endofimage#|>", add_special_tokens=False)["input_ids"][-1] | |
| pad_token_id = tokenizer(tokenizer.pad_token, add_special_tokens=False)["input_ids"][-1] | |
| bos_token_id = tokenizer(tokenizer.bos_token, add_special_tokens=False)["input_ids"][-1] | |
| except: | |
| pass | |
| def get_prompt(sample): | |
| return f"{tokenizer.bos_token}<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>Question: {sample['question'].strip()} Short answer:" | |
| # return f"<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>" | |
| model.eval().cuda() | |
| lang_encoder_name = model.lang_encoder.__class__.__name__.lower() | |
| if "peft" in lang_encoder_name: | |
| lang_encoder_name = model.lang_encoder.base_model.model.__class__.__name__.lower() | |
| predictions = [] | |
| tokenizer.padding_side = "left" | |
| if world_size > 1: | |
| torch.distributed.barrier() | |
| this_tot = 0 | |
| for ii, batch in enumerate(more_itertools.chunked( | |
| tqdm(eval_dataset, desc="Running inference", disable=(rank != 0)), batch_size | |
| )): | |
| if ii % world_size != rank: | |
| continue | |
| batch_images = prepare_batch_images( | |
| batch=batch, | |
| image_processor=image_processor, | |
| ).cuda() | |
| batch_text = [get_prompt(s) for s in batch] | |
| encodings = tokenizer( | |
| batch_text, | |
| return_tensors="pt", | |
| padding="longest", | |
| truncation=True, | |
| max_length=2000, | |
| ) | |
| input_ids = encodings["input_ids"].cuda() | |
| attention_mask = encodings["attention_mask"].cuda() | |
| skip_special_tokens = True | |
| if hasattr(model, "legacy") and model.legacy and "opt" in lang_encoder_name: | |
| if rank == 0: | |
| tqdm.write("use legacy model") | |
| for i in range(len(input_ids)): | |
| media_token_index = (input_ids[i] == media_token_id).nonzero()[0,0] | |
| endofmedia_token_index = (input_ids[i] == endofmedia_token_id).nonzero()[0,0] | |
| input_ids[i, media_token_index - 1] = media_token_id | |
| input_ids[i, media_token_index] = pad_token_id | |
| input_ids[i, endofmedia_token_index - 1] = endofmedia_token_id | |
| input_ids[i, endofmedia_token_index] = bos_token_id | |
| image_start_index_list = ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist() | |
| image_start_index_list = [[x] for x in image_start_index_list] | |
| image_nums = [1] * len(input_ids) | |
| if "llama" in lang_encoder_name: | |
| attention_mask[input_ids == 0] = 0 | |
| outputs = get_outputs( | |
| model=model, | |
| batch_images=batch_images, | |
| attention_mask=attention_mask, | |
| max_generation_length=10, | |
| min_generation_length=1, | |
| num_beams=5, | |
| length_penalty=0, | |
| input_ids=input_ids, | |
| image_start_index_list=image_start_index_list, | |
| image_nums=image_nums, | |
| ) | |
| # postprocess begin | |
| new_predictions = [ | |
| out.strip().lower().strip(string.punctuation+" ") for out in tokenizer.batch_decode(outputs, skip_special_tokens=skip_special_tokens) | |
| ] | |
| if vqa_dataset == "ok_vqa": | |
| new_predictions = postprocessor._lemmatize(new_predictions) | |
| if model.special: | |
| for i in range(len(new_predictions)): | |
| for answer, _ in Counter(batch[i]['answers']).most_common(): | |
| if answer in new_predictions[i]: | |
| new_predictions[i] = answer | |
| break | |
| if "cant" in new_predictions[i] and "no" == answer: | |
| new_predictions[i] = answer | |
| break | |
| if "can" in new_predictions[i] and "not" not in new_predictions[i] and "cant" not in new_predictions[i] and "yes" == answer: | |
| new_predictions[i] = answer | |
| break | |
| this_tot += 1 | |
| if rank == 0 and this_tot % 20 == 0: | |
| for i in range(1): | |
| tqdm.write(f"question: {batch[i]['question']}\nanswer: {batch[i]['answers']}model output: " + new_predictions[i]) | |
| predictions.extend( | |
| [ | |
| {"answer": p, "question_id": sample["question_id"], "_question": sample["question"], "answers": sample["answers"]} | |
| for p, sample in zip(new_predictions, batch) | |
| ] | |
| ) | |
| with open(f"{vqa_dataset}_{lang_encoder_name}_results_part{rank}_{id}.json", "w") as f: | |
| f.write(json.dumps(predictions)) | |
| print("save to", f"{vqa_dataset}_{lang_encoder_name}_results_part{rank}_{id}.json") | |
| time.sleep(10) | |
| if world_size > 1: | |
| torch.distributed.barrier() | |
| if rank == 0: | |
| print(f"evaluate on rank {rank}. world size is {world_size}") | |
| predictions = [] | |
| for rank_i in range(world_size): | |
| print("load", f"{vqa_dataset}_{lang_encoder_name}_results_part{rank_i}_{id}.json") | |
| predictions.extend(json.load(open(f"{vqa_dataset}_{lang_encoder_name}_results_part{rank_i}_{id}.json"))) | |
| os.remove(f"{vqa_dataset}_{lang_encoder_name}_results_part{rank_i}_{id}.json") | |
| print("num:", len(predictions)) | |
| # save the predictions to a temporary file | |
| random_uuid = str(uuid.uuid4()) | |
| with open(f"{vqa_dataset}results_{random_uuid}.json", "w") as f: | |
| f.write(json.dumps(predictions, indent=4)) | |
| if vqa_dataset == "gqa": | |
| acc = compute_gqa_accuracy(predictions) | |
| else: | |
| acc = compute_vqa_accuracy( | |
| f"{vqa_dataset}results_{random_uuid}.json", | |
| questions_json_path, | |
| annotations_json_path, | |
| vqa_dataset=vqa_dataset, | |
| ) | |
| print(vqa_dataset, "score:", acc, "| save to", f"{vqa_dataset}results_{random_uuid}.json") | |
| os.makedirs("eval_results", exist_ok=True) | |
| with open(os.path.join("eval_results", f"{vqa_dataset}_{model.expr_name}_{model.step_num}_{int(time.time())}_{acc}"), "w") as f: | |
| f.write(json.dumps(predictions, indent=2)) | |
| # delete the temporary file | |
| os.remove(f"{vqa_dataset}results_{random_uuid}.json") | |
| else: | |
| time.sleep(5) | |
| acc = 0.0 | |
| if world_size > 1: | |
| torch.distributed.barrier() | |
| return acc | |
| def evaluate_refcoco( | |
| model, | |
| tokenizer, | |
| image_processor, | |
| batch_size, | |
| tsvfile, | |
| max_generation_length=20, | |
| num_beams=3, | |
| length_penalty=-2.0, | |
| device=-1, | |
| vis_embed_size=None, | |
| rank=0, | |
| world_size=1, | |
| id=0, | |
| ): | |
| model.eval().cuda() | |
| loc_token_ids = [] | |
| for i in range(1000): | |
| loc_token_ids.append(int(tokenizer(f"<loc_{i}>", add_special_tokens=False)["input_ids"][-1])) | |
| media_token_id = tokenizer("<|#image#|>", add_special_tokens=False)["input_ids"][-1] | |
| endofmedia_token_id = tokenizer("<|#endofimage#|>", add_special_tokens=False)["input_ids"][-1] | |
| pad_token_id = tokenizer(tokenizer.pad_token, add_special_tokens=False)["input_ids"][-1] | |
| bos_token_id = tokenizer(tokenizer.bos_token, add_special_tokens=False)["input_ids"][-1] | |
| prebox_token_id = tokenizer("<|#prebox#|>", add_special_tokens=False)["input_ids"][-1] | |
| # all_ids = set(range(model.lang_encoder.lm_head.out_features)) | |
| # bad_words_ids = list(all_ids - set(loc_token_ids)) | |
| # bad_words_ids = [[b] for b in bad_words_ids] | |
| # min_loc_token_id = min(loc_token_ids) | |
| # max_loc_token_id = max(loc_token_ids) | |
| total = 0 | |
| correct = 0 | |
| ious = [] | |
| if "refcocog" in tsvfile: | |
| dataset_name = "refcocog" | |
| elif "refcocoplus" in tsvfile: | |
| dataset_name = "refcocoplus" | |
| else: | |
| dataset_name = "refcoco" | |
| with open(tsvfile, "r") as f: | |
| lines = f.readlines() | |
| pbar = tqdm(lines, disable=(rank != 0)) | |
| for ii, line in enumerate(pbar): | |
| if ii % world_size != rank: | |
| continue | |
| total += 1 | |
| line = line.rstrip() | |
| uniq_id, image_id, text, region_coord, image = line.split("\t") | |
| image = Image.open(BytesIO(base64.urlsafe_b64decode(image))).convert("RGB") | |
| # image = Image.open("/gpfs/u/home/LMCG/LMCGljnn/scratch/code/multimodal2/yolo.png").convert("RGB") | |
| # image = Image.open("/gpfs/u/home/LMCG/LMCGljnn/scratch/code/multimodal/temp/cat.png").convert("RGB") | |
| # image = Image.open("/gpfs/u/home/LMCG/LMCGljnn/scratch/code/multimodal/temp/262148000.png") | |
| gt_box = np.array(list(map(float, region_coord.split(",")))) | |
| width = image.width | |
| height = image.height | |
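| # resize the image to 224x224 and rescale the ground-truth box into the same coordinate frame | |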
| image = image.resize((224, 224)) | |
| gt_box = gt_box / np.array([width, height, width, height]) * 224 | |
| batch_images = image_processor(image).unsqueeze(0).unsqueeze(1).unsqueeze(0) | |
| prompt = [f"{tokenizer.bos_token}<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|><|#object#|>{text.rstrip('.').strip()}<|#endofobject#|><|#visual#|>"] | |
| # prompt = [f"<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>the cat<|#visual#|>"] | |
| # prompt = [f"<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>"] | |
| # prompt = [f"<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>a man<|#visual#|> is doing a trick on a skateboard<|#visual#|>"] | |
| encodings = tokenizer( | |
| prompt, | |
| padding="longest", | |
| truncation=True, | |
| return_tensors="pt", | |
| max_length=2000, | |
| ) | |
| input_ids = encodings["input_ids"] | |
| attention_mask = encodings["attention_mask"] | |
| # attention_mask[input_ids == prebox_token_id] = 0 | |
| image_start_index_list = ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist() | |
| image_start_index_list = [[x] for x in image_start_index_list] | |
| image_nums = [1] * len(input_ids) | |
| vision_x = batch_images.cuda() | |
| lang_x = input_ids.cuda() | |
| attention_mask = attention_mask.cuda() | |
| model.debug_id = 0 | |
| with torch.inference_mode(), torch.cuda.amp.autocast(dtype=torch.float16): | |
| outputs = model( | |
| vision_x=vision_x, | |
| lang_x=lang_x, | |
| attention_mask=attention_mask, | |
| labels=None, | |
| image_nums=image_nums, | |
| image_start_index_list=image_start_index_list, | |
| added_bbox_list=None, | |
| add_box=False, | |
| ) | |
| boxes = outputs["boxes"] | |
| scores = outputs["scores"] | |
| if len(scores) > 0: | |
| box = boxes[scores.argmax()] | |
| iou = get_iou(box, gt_box) | |
| else: | |
| iou = 0.0 | |
| # tqdm.write(f"output: {tokenizer.batch_decode(outputs)}") | |
| tqdm.write(f"no output for: {uniq_id}, {image_id}, {text}") | |
| if iou >= 0.5: | |
| correct += 1 | |
| pbar.set_description(f"iou: {iou:.2f} score: {correct / total:.4f}") | |
| # open_cv_image = np.array(image) | |
| # # Convert RGB to BGR | |
| # open_cv_image = open_cv_image[:, :, ::-1].copy() | |
| # for box, score in zip(boxes, scores): | |
| # open_cv_image = cv2.rectangle(open_cv_image, box[:2].astype(int), box[2:].astype(int), (255, 0, 0), 2) | |
| # cv2.imwrite("output.jpg", open_cv_image) | |
| # print(boxes) | |
| # print(scores) | |
| # exit() | |
| with open(f"{dataset_name}_results_part{rank}_{id}.json", "w") as f: | |
| f.write(json.dumps([total, correct])) | |
| if world_size > 1: | |
| torch.distributed.barrier() | |
| if rank == 0: | |
| total = 0 | |
| correct = 0 | |
| print(f"evaluate on rank {rank}. world size is {world_size}") | |
| for rank_i in range(world_size): | |
| [total_part, correct_part] = json.load(open(f"{dataset_name}_results_part{rank_i}_{id}.json")) | |
| os.remove(f"{dataset_name}_results_part{rank_i}_{id}.json") | |
| total += total_part | |
| correct += correct_part | |
| score = correct / total | |
| print("score:", score) | |
| with open(os.path.join("eval_results", f"{dataset_name}_{model.expr_name}_{model.step_num}_{int(time.time())}_{score}"), "w") as f: | |
| pass | |
| else: | |
| score = 0.0 | |
| if world_size > 1: | |
| torch.distributed.barrier() | |
| return score | |
| def preprocess_visual_info(Text): | |
| text = Text.split(" ") | |
| for is_idx, t in enumerate(text): | |
| if t == "is": | |
| break | |
| the_idx = is_idx | |
| while text[the_idx] != "the": | |
| the_idx -= 1 | |
| obj_A = " ".join(text[the_idx+1:is_idx]) | |
| second_the_idx = len(text) - 1 | |
| while text[second_the_idx] != "the": | |
| second_the_idx -= 1 | |
| obj_B = " ".join(text[second_the_idx+1:]) | |
| relation = " ".join(text[is_idx+1:second_the_idx]) | |
| visual_obj_A = f"<|#object#|>the {obj_A}<|#endofobject#|><|#visual#|><|#box#|><|#endofobject#|>" | |
| visual_obj_B = f"<|#object#|><|#previsual#|><|#prebox#|><|#object#|>the {obj_B}<|#endofobject#|>" | |
| Text = f"{visual_obj_A} is {relation} {visual_obj_B}" | |
| return Text, obj_A, visual_obj_A, obj_B, visual_obj_B, relation | |
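| # worked example: "the dog is sitting on the floor" -> obj_A="dog", relation="sitting on", | |
| # obj_B="floor", and Text becomes the tagged form | |
| # "<|#object#|>the dog<|#endofobject#|><|#visual#|><|#box#|><|#endofobject#|> is sitting on | |
| #  <|#object#|><|#previsual#|><|#prebox#|><|#object#|>the floor<|#endofobject#|>" | |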
| def get_bbox(visual_box_list, batch_images, prompt, model, tokenizer, media_token_id, prebox_token_id, mask_prebox, debug=False, return_all=False): | |
| assert isinstance(prompt, list) and len(prompt) == 1 and isinstance(prompt[0], str) | |
| encodings = tokenizer( | |
| prompt, | |
| padding="longest", | |
| truncation=True, | |
| return_tensors="pt", | |
| max_length=2000, | |
| ) | |
| input_ids = encodings["input_ids"] | |
| attention_mask = encodings["attention_mask"] | |
| image_start_index_list = ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist() | |
| image_start_index_list = [[x] for x in image_start_index_list] | |
| image_nums = [1] * len(input_ids) | |
| vision_x = batch_images.cuda() | |
| lang_x = input_ids.cuda() | |
| attention_mask = attention_mask.cuda() | |
| prebox_mask = (input_ids == prebox_token_id) | |
| if mask_prebox and prebox_mask.any(): | |
| attention_mask[prebox_mask] = 0 | |
| model.debug_id = 0 | |
| with torch.inference_mode(), torch.cuda.amp.autocast(dtype=torch.float16): | |
| outputs = model( | |
| vision_x=vision_x, | |
| lang_x=lang_x, | |
| attention_mask=attention_mask, | |
| labels=None, | |
| image_nums=image_nums, | |
| image_start_index_list=image_start_index_list, | |
| added_bbox_list=visual_box_list, | |
| add_box=visual_box_list is not None, | |
| relations=None, | |
| debug_mode=False, | |
| ) | |
| boxes = outputs["boxes"] | |
| scores = outputs["scores"] | |
| if debug: | |
| import pdb; pdb.set_trace() | |
| if return_all: | |
| return boxes, scores | |
| if len(scores) == 0: | |
| return None, None | |
| else: | |
| return boxes[scores.argmax()], scores.max() | |
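| # example call (as in evaluate_aro below): ground "the dog" by ending the prompt with | |
| # <|#visual#|> and reading back the highest-scoring predicted box: | |
| # box, score = get_bbox(None, batch_images, | |
| #                       ["<bos><|#image#|>...<|#endofimage#|><|#object#|>the dog<|#endofobject#|><|#visual#|>"], | |
| #                       model, tokenizer, media_token_id, prebox_token_id, mask_prebox=True) | |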
| def evaluate_aro( | |
| model, | |
| tokenizer, | |
| image_processor, | |
| batch_size, | |
| tsvfile, | |
| max_generation_length=20, | |
| num_beams=3, | |
| length_penalty=-2.0, | |
| device=-1, | |
| vis_embed_size=None, | |
| rank=0, | |
| world_size=1, | |
| id=0, | |
| add_visual=True, | |
| add_relation=False, | |
| subset=True, | |
| choose_left_right=True, | |
| ): | |
| os.makedirs(f"visualization/aro_results_{id}", exist_ok=True) | |
| from groundingdino.demo.caption_grounder import caption_grounder | |
| generator = caption_grounder( | |
| config_file="/gpfs/u/home/LMCG/LMCGljnn/scratch/code/multimodal/GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py", | |
| checkpoint_path="/gpfs/u/home/LMCG/LMCGljnn/scratch/code/multimodal/GroundingDINO/checkpoints/groundingdino_swint_ogc.pth", | |
| cpu_only=False, | |
| box_threshold=0.1, text_threshold=0.1, | |
| ) | |
| dataset_name = "aro" | |
| media_token_id = tokenizer("<|#image#|>", add_special_tokens=False)["input_ids"][-1] | |
| box_token_id = tokenizer("<|#box#|>", add_special_tokens=False)["input_ids"][-1] | |
| endofobject_token_id = tokenizer("<|#endofobject#|>", add_special_tokens=False)["input_ids"][-1] | |
| endofattr_token_id = tokenizer("<|#endofattr#|>", add_special_tokens=False)["input_ids"][-1] | |
| endofmedia_token_id = tokenizer("<|#endofimage#|>", add_special_tokens=False)["input_ids"][-1] | |
| visual_token_id = tokenizer("<|#visual#|>", add_special_tokens=False)["input_ids"][-1] | |
| previsual_token_id = tokenizer("<|#previsual#|>", add_special_tokens=False)["input_ids"][-1] | |
| prebox_token_id = tokenizer("<|#prebox#|>", add_special_tokens=False)["input_ids"][-1] | |
| model.eval().cuda() | |
| total = 0 | |
| correct = 0 | |
| from open_flamingo.eval.dataset_zoo import VG_Relation, VG_Attribution | |
| vgr_dataset = VG_Relation(image_preprocess=None, download=True, root_dir="/gpfs/u/home/LMCG/LMCGljnn/scratch/code/vision-language-models-are-bows/data") | |
| with open("/gpfs/u/home/LMCG/LMCGljnn/scratch/code/unilm/kosmos-2/labels.json") as f: | |
| all_labels = json.load(f) | |
| label_ids = tokenizer(all_labels).input_ids | |
| label_ids = sorted(list(set([x[0] for x in label_ids]))) | |
| if subset: | |
| subset_idx = json.load(open("aro_subset.json")) | |
| pbar = tqdm(subset_idx, disable=(rank != 0)) | |
| else: | |
| pbar = tqdm(vgr_dataset, disable=(rank != 0)) | |
| exist_total = 0 | |
| for ii, sample in enumerate(pbar): | |
| if subset: | |
| ORI_IDX = int(sample) | |
| sample = vgr_dataset[sample] | |
| # if ORI_IDX != 19036: | |
| # continue | |
| if ii % world_size != rank: | |
| continue | |
| not_left_right = ("near" in sample["caption_options"][0] or "next to" in sample["caption_options"][0] or "in front of" in sample["caption_options"][0] or "behind" in sample["caption_options"][0]) or ("left" not in sample["caption_options"][0] and "right" not in sample["caption_options"][0]) | |
| if (choose_left_right and not_left_right) or (not choose_left_right and not not_left_right): | |
| if rank == 0: | |
| tqdm.write(f"SKIP: {sample['caption_options'][1]}") | |
| continue | |
| total += 1 | |
| image = sample["image_options"][0] | |
| # image = Image.open("/gpfs/u/home/LMCG/LMCGljnn/scratch/code/multimodal2/yolo.png").convert("RGB") | |
| image = image.resize((224, 224)) | |
| chosen_idx = 0 | |
| text = sample["caption_options"][chosen_idx] # 1 is true caption | |
| # text = "the dog is sitting on the floor" if idx == 1 else "the floor is sitting on the dog" | |
| batch_images = image_processor(image).unsqueeze(0).unsqueeze(1).unsqueeze(0) | |
| text, obj_A, visual_obj_A, obj_B, visual_obj_B, relation = preprocess_visual_info(text) | |
| first_text = f"<|#object#|>the {obj_A}<|#endofobject#|><|#visual#|>" | |
| prompt = [f"{tokenizer.bos_token}<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>{first_text}"] | |
| first_box, first_score = get_bbox(None, batch_images, prompt, model, tokenizer, media_token_id, prebox_token_id, mask_prebox=True, return_all=False) | |
| # use grounding DINO to get the first bbox | |
| # caption = f"{obj_A}" | |
| # with torch.no_grad(): | |
| # logits, boxes = generator.ground_caption_raw(image_pil=image, caption=caption) | |
| # boxes_filt, pred_phrases = generator.postprocess(logits, boxes, generator.ground_model, caption, generator.text_threshold, generator.box_threshold, with_logits=True) | |
| # objects = {} | |
| # for box, phrase in zip(boxes_filt, pred_phrases): | |
| # obj, score = phrase | |
| # obj = obj[0] | |
| # if obj not in objects: | |
| # objects[obj] = (score, box) | |
| # if objects[obj][0] < score: | |
| # objects[obj] = (score, box) | |
| # try: | |
| # first_box = objects[obj_A][1].clone() | |
| # first_box[:2] -= first_box[2:] / 2 | |
| # first_box[2:] += first_box[:2] | |
| # first_box = first_box.clamp(0, 0.99) * 224.0 | |
| # first_box = first_box.numpy() | |
| # first_score = objects[obj_A][0] | |
| # except: | |
| # first_box = None | |
| if first_box is None: | |
| text_A = "the " + obj_A | |
| added_bbox_list = None | |
| else: | |
| text_A = visual_obj_A | |
| added_bbox_list = [torch.tensor(first_box).unsqueeze(0).cuda() / 224] | |
| prompt = [f"{tokenizer.bos_token}<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>{text_A} is {relation}<|#object#|><|#previsual#|>"] | |
| pre_boxes, pre_scores = get_bbox(added_bbox_list, batch_images, prompt, model, tokenizer, media_token_id, | |
| prebox_token_id, mask_prebox=False, debug=False, return_all=True) | |
| open_cv_image = np.array(image) | |
| open_cv_image = open_cv_image[:, :, ::-1].copy() | |
| font = cv2.FONT_HERSHEY_SIMPLEX | |
| fontScale = 0.5 | |
| color = (0, 0, 0) | |
| thickness = 1 | |
| if first_box is not None: | |
| open_cv_image = cv2.rectangle(open_cv_image, first_box[:2].astype(int), first_box[2:].astype(int), (255, 0, 0), 2) | |
| exist_flag = False | |
| for box, score in zip(pre_boxes, pre_scores): | |
| if score >= 0.5: | |
| exist_flag = True | |
| open_cv_image = cv2.rectangle(open_cv_image, box[:2].astype(int), box[2:].astype(int), (0, 255, 0), 2) | |
| org = box[:2].astype(int) | |
| org[1] += 20 | |
| org[0] += 10 | |
| open_cv_image = cv2.putText(open_cv_image, f"{score:.2f}", org, font, fontScale, (255, 255, 255), thickness, cv2.LINE_AA) | |
| open_cv_image = cv2.resize(open_cv_image, (512, 512)) | |
| put_text = sample["caption_options"][chosen_idx] | |
| org = [10, 20] | |
| open_cv_image = cv2.putText(open_cv_image, put_text, org, font, fontScale, color, thickness, cv2.LINE_AA) | |
| # cv2.imwrite(f"visualization/aro_results_{id}/{str(ORI_IDX).zfill(8)}.jpg", open_cv_image) | |
| if exist_flag: | |
| exist_total += 1 | |
| continue | |
| if pre_boxes is None: | |
| pre_boxes = [np.array([0.0, 0.0, 223.0, 223.0])] | |
| pre_scores = [1.0] | |
| rank_list = [] | |
| # pre_boxes = [pre_boxes[0]] | |
| # pre_scores = [pre_scores[0]] | |
| for pre_box, pre_score in zip(pre_boxes, pre_scores): | |
| prompt = [f"{tokenizer.bos_token}<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>{text_A} is {relation}<|#object#|><|#previsual#|><|#prebox#|><|#object#|> the {obj_B}<|#endofobject#|>"] | |
| encodings = tokenizer( | |
| prompt, | |
| padding="longest", | |
| truncation=True, | |
| return_tensors="pt", | |
| max_length=512, | |
| ) | |
| input_ids = encodings["input_ids"] | |
| attention_mask = encodings["attention_mask"] | |
| image_start_index_list = ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist() | |
| image_start_index_list = [[x] for x in image_start_index_list] | |
| image_nums = [1] * len(input_ids) | |
| vision_x = batch_images.cuda() | |
| lang_x = input_ids.cuda() | |
| attention_mask = attention_mask.cuda() | |
| labels = lang_x.clone() | |
| answer_start_idx = (labels == tokenizer("<|#object#|>", add_special_tokens=False)["input_ids"][-1]).nonzero()[-1][1] + 1 | |
| # pre_box = None | |
| labels[0, :answer_start_idx] = -100 | |
| # # labels[labels == endofobject_token_id] = -100 | |
| # labels[:, 0] = -100 | |
| # labels[labels == visual_token_id] = -100 | |
| # labels[labels == box_token_id] = -100 | |
| # labels[labels == previsual_token_id] = -100 | |
| # labels[labels == prebox_token_id] = -100 | |
| # labels[labels == endofattr_token_id] = -100 | |
| # labels[labels == tokenizer.pad_token_id] = -100 | |
| # labels[labels == media_token_id] = -100 | |
| # labels[labels == endofmedia_token_id] = -100 | |
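| # mask the prompt prefix and every special/visual token so the labels (and the logits | |
| # ranked below from answer_start_idx) only concern the answer span for obj_B | |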
| answer_ids = tokenizer(f" {obj_B}", add_special_tokens=False)["input_ids"] | |
| labels[input_ids == visual_token_id] = -100 | |
| labels[input_ids == box_token_id] = -100 | |
| labels[input_ids == endofattr_token_id] = -100 | |
| labels[input_ids == previsual_token_id] = -100 | |
| labels[input_ids == prebox_token_id] = -100 | |
| labels[torch.roll(input_ids == prebox_token_id, 1)] = -100 | |
| labels[torch.roll(input_ids == box_token_id, 1)] = -100 | |
| labels[:, 0] = -100 | |
| labels[input_ids == tokenizer.pad_token_id] = -100 | |
| labels[input_ids == media_token_id] = -100 | |
| labels[input_ids == endofmedia_token_id] = -100 | |
| added_bbox_list = None | |
| if add_visual: | |
| added_bbox_list = [] | |
| if first_box is not None: | |
| added_bbox_list.append(torch.tensor(first_box).unsqueeze(0).cuda().float() / 224) | |
| if pre_box is not None: | |
| added_bbox_list.append(torch.tensor(pre_box).unsqueeze(0).cuda().float() / 224) | |
| if added_bbox_list is not None and len(added_bbox_list) == 0: | |
| added_bbox_list = None | |
| with torch.cuda.amp.autocast(dtype=torch.float16), torch.no_grad(): | |
| outputs = model( | |
| vision_x=vision_x, | |
| lang_x=lang_x, | |
| attention_mask=attention_mask, | |
| labels=labels, | |
| image_nums=image_nums, | |
| image_start_index_list=image_start_index_list, | |
| added_bbox_list=added_bbox_list, | |
| add_box=added_bbox_list is not None, | |
| relations=None, | |
| ) | |
| logits = outputs["logits"][0, answer_start_idx:] | |
| _rank = logits[0][label_ids].sort(descending=True).indices.tolist().index(label_ids.index(answer_ids[0])) | |
| rank_list.append(_rank) | |
| # open_cv_image = np.array(image) | |
| # open_cv_image = open_cv_image[:, :, ::-1].copy() | |
| # if first_box is not None: | |
| # open_cv_image = cv2.rectangle(open_cv_image, first_box[:2].astype(int), first_box[2:].astype(int), (255, 0, 0), 2) | |
| # if pre_box is not None: | |
| # open_cv_image = cv2.rectangle(open_cv_image, pre_box[:2].astype(int), pre_box[2:].astype(int), (0, 255, 0), 2) | |
| # font = cv2.FONT_HERSHEY_SIMPLEX | |
| # org = [10, 20] | |
| # fontScale = 0.5 | |
| # color = (0, 0, 0) | |
| # thickness = 1 | |
| # open_cv_image = cv2.resize(open_cv_image, (512, 512)) | |
| # put_text = sample["caption_options"][1] | |
| # open_cv_image = cv2.putText(open_cv_image, put_text, org, font, fontScale, color, thickness, cv2.LINE_AA) | |
| # org[1] += 20 | |
| # put_text = "top10 in green box" | |
| # open_cv_image = cv2.putText(open_cv_image, put_text, org, font, fontScale, color, thickness, cv2.LINE_AA) | |
| # fontScale = 1.0 | |
| # thickness = 2 | |
| # for ind in logits_list[i][0].sort(descending=True).indices[:10]: | |
| # org[1] += 20 | |
| # put_text = f"{tokenizer.decode(ind)}" | |
| # open_cv_image = cv2.putText(open_cv_image, put_text, org, font, fontScale, color, thickness, cv2.LINE_AA) | |
| # tqdm.write(f"{tokenizer.decode(logits_list[i][0].sort(descending=True).indices[:10])}") | |
| # tqdm.write(f"{rank_list}") | |
| final_rank = min(rank_list) | |
| if final_rank < 10: | |
| correct += 1 | |
| TYPE = "CORRECT" | |
| if rank == 0: | |
| tqdm.write(f"correct: {final_rank} " + prompt[0].replace(tokenizer.pad_token, "")) | |
| else: | |
| TYPE = "WRONG" | |
| if rank == 0: | |
| tqdm.write(f"wrong: {final_rank} " + prompt[0].replace(tokenizer.pad_token, "")) | |
| # cv2.imwrite(f"visualization/aro_results_{id}/{TYPE}_{ORI_IDX}.jpg", open_cv_image) | |
| pbar.set_description(f"score: {correct / total:.4f} | {final_rank}") | |
| print(exist_total) | |
| exit() | |
| with open(f"{dataset_name}_results_part{rank}_{id}.json", "w") as f: | |
| f.write(json.dumps([total, correct])) | |
| if world_size > 1: | |
| torch.distributed.barrier() | |
| if rank == 0: | |
| total = 0 | |
| correct = 0 | |
| print(f"evaluate on rank {rank}. world size is {world_size}") | |
| for rank_i in range(world_size): | |
| [total_part, correct_part] = json.load(open(f"{dataset_name}_results_part{rank_i}_{id}.json")) | |
| os.remove(f"{dataset_name}_results_part{rank_i}_{id}.json") | |
| total += total_part | |
| correct += correct_part | |
| score = correct / total | |
| print("score:", score, "total:", total) | |
| with open(os.path.join("eval_results", f"{dataset_name}_{model.expr_name}_{model.step_num}_{int(time.time())}_{score}"), "w") as f: | |
| pass | |
| else: | |
| score = 0.0 | |
| if world_size > 1: | |
| torch.distributed.barrier() | |
| return score | |
| def evaluate_aro_ori( | |
| model, | |
| tokenizer, | |
| image_processor, | |
| batch_size, | |
| tsvfile, | |
| max_generation_length=20, | |
| num_beams=3, | |
| length_penalty=-2.0, | |
| device=-1, | |
| vis_embed_size=None, | |
| rank=0, | |
| world_size=1, | |
| id=0, | |
| add_visual=True, | |
| add_relation=False, | |
| subset=True, | |
| choose_left_right=True, | |
| only_highest=True, | |
| ): | |
| os.makedirs(f"visualization/aro_results_{id}", exist_ok=True) | |
| dataset_name = "aroori" | |
| media_token_id = tokenizer("<|#image#|>", add_special_tokens=False)["input_ids"][-1] | |
| box_token_id = tokenizer("<|#box#|>", add_special_tokens=False)["input_ids"][-1] | |
| endofobject_token_id = tokenizer("<|#endofobject#|>", add_special_tokens=False)["input_ids"][-1] | |
| endofattr_token_id = tokenizer("<|#endofattr#|>", add_special_tokens=False)["input_ids"][-1] | |
| endofmedia_token_id = tokenizer("<|#endofimage#|>", add_special_tokens=False)["input_ids"][-1] | |
| visual_token_id = tokenizer("<|#visual#|>", add_special_tokens=False)["input_ids"][-1] | |
| previsual_token_id = tokenizer("<|#previsual#|>", add_special_tokens=False)["input_ids"][-1] | |
| prebox_token_id = tokenizer("<|#prebox#|>", add_special_tokens=False)["input_ids"][-1] | |
| model.eval().cuda() | |
| total = 0 | |
| correct = 0 | |
| from open_flamingo.eval.dataset_zoo import VG_Relation, VG_Attribution | |
| vgr_dataset = VG_Relation(image_preprocess=None, download=True, root_dir="/gpfs/u/home/LMCG/LMCGljnn/scratch/code/vision-language-models-are-bows/data") | |
| if subset: | |
| subset_idx = json.load(open("aro_subset.json")) | |
| pbar = tqdm(subset_idx, disable=(rank != 0)) | |
| else: | |
| pbar = tqdm(vgr_dataset, disable=(rank != 0)) | |
| for ii, sample in enumerate(pbar): | |
| if subset: | |
| ORI_IDX = int(sample) | |
| sample = vgr_dataset[sample] | |
| # if ORI_IDX != 19036: | |
| # continue | |
| if ii % world_size != rank: | |
| continue | |
| not_left_right = ("near" in sample["caption_options"][0] or "next to" in sample["caption_options"][0] or "in front of" in sample["caption_options"][0] or "behind" in sample["caption_options"][0]) or ("left" not in sample["caption_options"][0] and "right" not in sample["caption_options"][0]) | |
| if (choose_left_right and not_left_right) or (not choose_left_right and not not_left_right): | |
| if rank == 0: | |
| tqdm.write(f"SKIP: {sample['caption_options'][1]}") | |
| continue | |
| total += 1 | |
| image = sample["image_options"][0] | |
| # image = Image.open("/gpfs/u/home/LMCG/LMCGljnn/scratch/code/multimodal2/yolo.png").convert("RGB") | |
| image = image.resize((224, 224)) | |
| debug_data = [] | |
| final_losses = [] | |
| for idx in range(2): | |
| text = sample["caption_options"][idx] # 1 is true caption | |
| # text = "the dog is sitting on the floor" if idx == 1 else "the floor is sitting on the dog" | |
| batch_images = image_processor(image).unsqueeze(0).unsqueeze(1).unsqueeze(0) | |
| text, obj_A, visual_obj_A, obj_B, visual_obj_B, relation = preprocess_visual_info(text) | |
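| # Stage 1: prompt the model to ground obj_A; get_bbox returns its highest-scoring box (or None). | |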
| first_text = f"<|#object#|>the {obj_A}<|#endofobject#|><|#visual#|>" | |
| prompt = [f"{tokenizer.bos_token}<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>{first_text}"] | |
| first_box, first_score = get_bbox(None, batch_images, prompt, model, tokenizer, media_token_id, prebox_token_id, mask_prebox=True, return_all=False) | |
| if first_box is None: | |
| text_A = "the " + obj_A | |
| added_bbox_list = None | |
| else: | |
| text_A = visual_obj_A | |
| added_bbox_list = [torch.tensor(first_box).unsqueeze(0).cuda() / 224] | |
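| # Stage 2: ask for candidate boxes of obj_B via the <|#previsual#|> token, conditioning on obj_A's box when one was found. | |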
| prompt = [f"{tokenizer.bos_token}<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>{text_A} is {relation}<|#object#|><|#previsual#|>"] | |
| pre_boxes, pre_scores = get_bbox(added_bbox_list, batch_images, prompt, model, tokenizer, media_token_id, | |
| prebox_token_id, mask_prebox=False, debug=False, return_all=True) | |
| if pre_boxes is None: | |
| pre_boxes = [np.array([0.0, 0.0, 223.0, 223.0])] | |
| pre_scores = [1.0] | |
| loss_list = [] | |
| if only_highest: | |
| pre_boxes = [pre_boxes[0]] | |
| pre_scores = [pre_scores[0]] | |
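| # Score the caption once per retained pre-box; the minimum loss across candidates becomes final_loss for this option. | |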
| for pre_box, pre_score in zip(pre_boxes, pre_scores): | |
| prompt = [f"{tokenizer.bos_token}<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>{text_A} is {relation}<|#object#|><|#previsual#|><|#prebox#|><|#object#|> the {obj_B}<|#endofobject#|>"] | |
| encodings = tokenizer( | |
| prompt, | |
| padding="longest", | |
| truncation=True, | |
| return_tensors="pt", | |
| max_length=512, | |
| ) | |
| input_ids = encodings["input_ids"] | |
| attention_mask = encodings["attention_mask"] | |
| image_start_index_list = ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist() | |
| image_start_index_list = [[x] for x in image_start_index_list] | |
| image_nums = [1] * len(input_ids) | |
| vision_x = batch_images.cuda() | |
| lang_x = input_ids.cuda() | |
| attention_mask = attention_mask.cuda() | |
| labels = lang_x.clone() | |
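| # Mask grounding/special tokens, padding, and the BOS position so the loss covers only the caption words. | |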
| labels[input_ids == visual_token_id] = -100 | |
| labels[input_ids == box_token_id] = -100 | |
| labels[input_ids == endofattr_token_id] = -100 | |
| labels[input_ids == previsual_token_id] = -100 | |
| labels[input_ids == prebox_token_id] = -100 | |
| labels[torch.roll(input_ids == prebox_token_id, 1)] = -100 | |
| labels[torch.roll(input_ids == box_token_id, 1)] = -100 | |
| labels[:, 0] = -100 | |
| labels[input_ids == tokenizer.pad_token_id] = -100 | |
| labels[input_ids == media_token_id] = -100 | |
| labels[input_ids == endofmedia_token_id] = -100 | |
| added_bbox_list = None | |
| if add_visual: | |
| added_bbox_list = [] | |
| if first_box is not None: | |
| added_bbox_list.append(torch.tensor(first_box).unsqueeze(0).cuda().float() / 224) | |
| if pre_box is not None: | |
| added_bbox_list.append(torch.tensor(pre_box).unsqueeze(0).cuda().float() / 224) | |
| if added_bbox_list is not None and len(added_bbox_list) == 0: | |
| added_bbox_list = None | |
| with torch.cuda.amp.autocast(dtype=torch.float16), torch.no_grad(): | |
| outputs = model( | |
| vision_x=vision_x, | |
| lang_x=lang_x, | |
| attention_mask=attention_mask, | |
| labels=labels, | |
| image_nums=image_nums, | |
| image_start_index_list=image_start_index_list, | |
| added_bbox_list=added_bbox_list, | |
| add_box=added_bbox_list is not None, | |
| relations=None, | |
| ) | |
| loss_list.append((outputs["loss"].sum() / (outputs["loss"] != 0).sum()).item()) | |
| debug_data.append([outputs, first_box, first_score, pre_box, pre_scores]) | |
| final_loss = min(loss_list) | |
| final_losses.append(final_loss) | |
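| # caption_options[1] is the true caption, so the sample counts as correct when its loss is not higher than the distractor's. | |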
| if final_losses[0] >= final_losses[1]: | |
| correct += 1 | |
| else: | |
| pass | |
| pbar.set_description(f"score: {correct / total:.4f} | {final_losses[0]:.2f} vs {final_losses[1]:.2f}") | |
| with open(f"{dataset_name}_results_part{rank}_{id}.json", "w") as f: | |
| f.write(json.dumps([total, correct])) | |
| if world_size > 1: | |
| torch.distributed.barrier() | |
| if rank == 0: | |
| total = 0 | |
| correct = 0 | |
| print(f"evaluate on rank {rank}. world size is {world_size}") | |
| for rank_i in range(world_size): | |
| [total_part, correct_part] = json.load(open(f"{dataset_name}_results_part{rank_i}_{id}.json")) | |
| os.remove(f"{dataset_name}_results_part{rank_i}_{id}.json") | |
| total += total_part | |
| correct += correct_part | |
| score = correct / total | |
| print("score:", score, "total:", total) | |
| with open(os.path.join("eval_results", f"{dataset_name}_{model.expr_name}_{model.step_num}_{int(time.time())}_{score}"), "w") as f: | |
| pass | |
| else: | |
| score = 0.0 | |
| if world_size > 1: | |
| torch.distributed.barrier() | |
| return score | |
| def evaluate_pisc( | |
| model, | |
| tokenizer, | |
| image_processor, | |
| batch_size, | |
| tsvfile, | |
| max_generation_length=20, | |
| num_beams=3, | |
| length_penalty=-2.0, | |
| device=-1, | |
| vis_embed_size=None, | |
| rank=0, | |
| world_size=1, | |
| id=0, | |
| add_visual=True, | |
| ): | |
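| # evaluate_pisc: PISC social-relationship evaluation. Each person pair is scored against all six | |
| # relation templates by LM loss; the lowest-loss relation is the prediction and per-class recall is reported. | |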
| from open_flamingo.train.instruction_template import PISC_TEMPLATES | |
| dataset_name = "pisc" | |
| media_token_id = tokenizer("<|#image#|>", add_special_tokens=False)["input_ids"][-1] | |
| box_token_id = tokenizer("<|#box#|>", add_special_tokens=False)["input_ids"][-1] | |
| endofobject_token_id = tokenizer("<|#endofobject#|>", add_special_tokens=False)["input_ids"][-1] | |
| endofattr_token_id = tokenizer("<|#endofattr#|>", add_special_tokens=False)["input_ids"][-1] | |
| endofmedia_token_id = tokenizer("<|#endofimage#|>", add_special_tokens=False)["input_ids"][-1] | |
| visual_token_id = tokenizer("<|#visual#|>", add_special_tokens=False)["input_ids"][-1] | |
| model.eval().cuda() | |
| dataset = wds.WebDataset("/gpfs/u/home/LMCG/LMCGljnn/scratch-shared/junyan/raw/instruct/eval/pisc/000000.tar").decode().to_tuple("image_path.txt", "dataset.txt", "data.pyd") | |
| pbar = tqdm(dataset, disable=(rank != 0)) | |
| rel_id_to_type = ["friends", "family", "couple", "professional", "commercial", "no relation"] | |
| rel_type_to_id = {x: i for i, x in enumerate(rel_id_to_type)} | |
| gt = [] | |
| pred_scores = [] | |
| for III, sample in enumerate(pbar): | |
| if III % world_size != rank: | |
| continue | |
| image_path, dataset, data = sample | |
| image = Image.open(image_path) | |
| size = image_processor.transforms[0].size | |
| image = image.resize((size, size)) | |
| batch_images = image_processor(image).unsqueeze(0).unsqueeze(1).unsqueeze(0) | |
| boxA = data[0] | |
| boxB = data[1] | |
| gt_relation = data[2] | |
| losses = [] | |
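| # Score each candidate relation by the loss of its filled-in PISC template, conditioned on the two person boxes. | |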
| for i_rel, option_rel in enumerate(rel_id_to_type): | |
| text = PISC_TEMPLATES[0].format(relation=option_rel) | |
| added_bbox = [ | |
| torch.tensor([boxA]).cuda(), | |
| torch.tensor([boxB]).cuda(), | |
| ] | |
| caption = f"{tokenizer.bos_token}<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>{text}{tokenizer.eos_token}" | |
| encodings = tokenizer( | |
| caption, | |
| padding="longest", | |
| truncation=True, | |
| return_tensors="pt", | |
| max_length=2000, | |
| ) | |
| input_ids = encodings["input_ids"] | |
| attention_mask = encodings["attention_mask"] | |
| image_start_index_list = ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist() | |
| image_start_index_list = [[x] for x in image_start_index_list] | |
| image_nums = [1] * len(input_ids) | |
| vision_x = batch_images.cuda() | |
| lang_x = input_ids.cuda() | |
| attention_mask = attention_mask.cuda() | |
| labels = lang_x.clone() | |
| labels[labels == tokenizer.pad_token_id] = -100 | |
| if add_visual: | |
| # endofattr_next_token_index = list((labels == endofattr_token_id).nonzero(as_tuple=True)) | |
| # endofattr_next_token_index[1] += 1 | |
| # endofattr_next_token_id = labels[endofattr_next_token_index] | |
| # </obj><visual><box></attr>NEXT_WORD | |
| # </obj> predict NEXT_WORD | |
| # <visual><box></attr> predict nothing | |
| labels[labels == visual_token_id] = -100 | |
| labels[labels == box_token_id] = -100 | |
| labels[labels == endofattr_token_id] = -100 | |
| # labels[endofattr_next_token_index] = -100 | |
| labels[:, 0] = -100 | |
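| # Mask everything up to (and including) the " Answer" marker so the loss is computed only on the relation word. | |
| # Note: taking index 0 of the " Answer" encoding assumes the tokenizer does not prepend a BOS token here. | |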
| answer_token_id = tokenizer(" Answer").input_ids[0] | |
| answer_token_loc = (input_ids == answer_token_id).nonzero() | |
| for batch_idx, idx in answer_token_loc: | |
| labels[batch_idx][:idx+2] = -100 | |
| with torch.cuda.amp.autocast(dtype=torch.float16), torch.no_grad(): | |
| outputs = model( | |
| vision_x=vision_x, | |
| lang_x=lang_x, | |
| attention_mask=attention_mask, | |
| labels=labels, | |
| image_nums=image_nums, | |
| image_start_index_list=image_start_index_list, | |
| added_bbox_list=added_bbox, | |
| add_box=added_bbox is not None, | |
| ) | |
| loss_total = outputs.loss.reshape(labels.shape[0], -1) | |
| loss = loss_total.sum() / (loss_total != 0).sum() | |
| losses.append(loss.item()) | |
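| # Convert per-relation losses to probability-like scores with a softmax over negative losses. | |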
| pred_scores.append(np.exp(-np.array(losses)) / np.exp(-np.array(losses)).sum()) | |
| gt.append(rel_type_to_id[gt_relation]) | |
| gt = np.array(gt) | |
| pred_scores = np.array(pred_scores) | |
| pred = pred_scores.argmax(1) | |
| print("total num:", len(gt)) | |
| recalls = recall_score(y_true=gt, y_pred=pred, average=None, labels=[0,1,2,3,4,5]) | |
| print("recalls:", recalls) | |
| with open(f"{dataset_name}_results_part{rank}_{id}.json", "w") as f: | |
| f.write(json.dumps([gt.tolist(), pred.tolist()])) | |
| if world_size > 1: | |
| torch.distributed.barrier() | |
| if rank == 0: | |
| gt = [] | |
| pred = [] | |
| print(f"evaluate on rank {rank}. world size is {world_size}") | |
| for rank_i in range(world_size): | |
| [gt_part, pred_part] = json.load(open(f"{dataset_name}_results_part{rank_i}_{id}.json")) | |
| os.remove(f"{dataset_name}_results_part{rank_i}_{id}.json") | |
| gt.extend(gt_part) | |
| pred.extend(pred_part) | |
| print("total num:", len(gt)) | |
| recalls = recall_score(y_true=gt, y_pred=pred, average=None, labels=[0,1,2,3,4,5]) | |
| print("recalls:", recalls) | |
| with open(os.path.join("eval_results", f"{dataset_name}_{model.expr_name}_{model.step_num}_{int(time.time())}"), "w") as f: | |
| f.write(f"{gt}\n") | |
| f.write(f"{pred}\n") | |
| f.write(f"{recalls}\n") | |
| score = 0.0 | |
| if world_size > 1: | |
| torch.distributed.barrier() | |
| return score | |
| if __name__ == "__main__": | |
| main() | |