File size: 3,303 Bytes
32b542e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
# Copyright 2021 JD.com, Inc., JD AI
"""
@author: Yehao Li
@contact: [email protected]
"""
import os
import sys
import pickle
import json
from json import encoder
from .build import EVALUATION_REGISTRY

@EVALUATION_REGISTRY.register()
class VQAEvaler(object):
    def __init__(self, cfg, annfile, output_dir):
        super(VQAEvaler, self).__init__()
        label2ans_path = os.path.join(cfg.DATALOADER.ANNO_FOLDER, "trainval_label2ans.pkl")
        ori_annotation = json.load(open(os.path.join(cfg.DATALOADER.ANNO_FOLDER, "v2_mscoco_val2014_annotations.json")))
        self.id2type = {t['question_id']: t['question_type'] for t in ori_annotation['annotations']}
        self.id2type_answer = {t['question_id']: t['answer_type'] for t in ori_annotation['annotations']}
        self.label2ans = pickle.load(open(label2ans_path, "rb"))

        self.id2label = {}
        if len(annfile) > 0:
            answers_val = pickle.load(open(annfile, "rb"))
            for datum in answers_val:
                quesid = datum['question_id']
                self.id2label[quesid] = {}
                for i, label in enumerate(datum['labels']):
                    label_str = self.label2ans[label]
                    self.id2label[quesid][label_str] = datum['scores'][i]

        if output_dir is not None:
            self.output_dir = os.path.join(output_dir, 'results')
            if not os.path.exists(self.output_dir):
                os.mkdir(self.output_dir)
        else:
            self.output_dir = None

    def eval(self, results, epoch):
        for res in results:
            res['answer'] = self.label2ans[res['answer']]
        if self.output_dir is not None:
            json.dump(results, open(os.path.join(self.output_dir, str(epoch) + '.json'), "w", encoding="utf-8"))

        accuracy = 0.
        acc_by_type = dict()
        acc_by_type_answer = dict()
        for result in results:
            quesid = result['question_id']
            ans = result['answer']
            if quesid not in self.id2type:
                print("Test Stage has no target")
                return { "accuracy": 0.0 }
            type = self.id2type[quesid]
            ans_type = self.id2type_answer[quesid]
            if type not in acc_by_type:
                acc_by_type[type] = [0,0]
            if ans_type not in acc_by_type_answer:
                acc_by_type_answer[ans_type] = [0,0]
            if quesid not in self.id2label:
                return { "accuracy": 0.0 }

            datum = self.id2label[quesid]
            acc_by_type[type][1] += 1
            acc_by_type_answer[ans_type][1] += 1
            if ans in datum:
                accuracy += datum[ans]
                acc_by_type[type][0] += datum[ans]
                acc_by_type_answer[ans_type][0] += datum[ans]

        accuracy = accuracy / len(results)
        print('vqa acc: {}'.format(accuracy*100))
        for k in acc_by_type.keys():
            acc_by_type[k] = acc_by_type[k][0] / acc_by_type[k][1] * 100
        for k in acc_by_type_answer.keys():
            acc_by_type_answer[k] = acc_by_type_answer[k][0] / acc_by_type_answer[k][1] * 100
        print('vqa acc by question type: ', acc_by_type)
        print("vqa acc by answer type: ", acc_by_type_answer)
        return { "accuracy": accuracy }