Spaces:
Running
on
A100
Running
on
A100
File size: 2,282 Bytes
174ae06 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 |
# This file is originated from the official MMMU codebase:
# https://github.com/MMMU-Benchmark/MMMU
import random
import numpy as np
def parse_choice(response, all_choices, index2ans=None):
"""
Parse the prediction from the generated response.
Return the predicted index e.g., A, B, C, D.
"""
for char in [",", ".", "!", "?", ";", ":", "'"]:
response = response.strip(char)
response = " " + response + " " # add space to avoid partial match
index_ans = True
ans_with_brack = False
candidates = []
for choice in all_choices: # e.g., (A) (B) (C) (D)
if f"({choice})" in response:
candidates.append(choice)
ans_with_brack = True
if len(candidates) == 0:
for choice in all_choices: # e.g., A B C D
if f" {choice} " in response:
candidates.append(choice)
# if all above doesn't get candidates, check if the content is larger than 5 tokens and try to parse the example
if len(candidates) == 0 and len(response.split()) > 5 and index2ans is not None:
for index, ans in index2ans.items():
if ans.lower() in response.lower():
candidates.append(index)
index_ans = False # it's content ans.
if len(candidates) == 0: # still not get answer, randomly choose one.
pred_index = random.choice(all_choices)
elif len(candidates) > 1:
start_indexes = []
if index_ans:
if ans_with_brack:
for can in candidates:
index = response.rfind(f"({can})")
start_indexes.append(index) # -1 will be ignored anyway
# start_indexes = [generated_response.index(f'({can})') for can in candidates]
else:
for can in candidates:
index = response.rfind(f" {can} ")
start_indexes.append(index)
else:
for can in candidates:
index = response.lower().rfind(index2ans[can].lower())
start_indexes.append(index)
# get the last one
pred_index = candidates[np.argmax(start_indexes)]
else: # if only one candidate, use it.
pred_index = candidates[0]
return pred_index
|