# GreenPLM/pointllm/eval/eval_objaverse.py
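"""Evaluate a point-cloud LLM (GreenPLM) on the Objaverse benchmark.

Loads a LLaVA/Phi-3-style checkpoint, optionally applies LoRA weights,
attaches the pretrained point-cloud encoder and MLP adapter, then runs
batched generation over the Objaverse validation set and saves the model
outputs alongside the ground-truth annotations as JSON.

Example invocation (paths are illustrative, not shipped defaults):
    python eval_objaverse.py \
        --model_path ./lava-vicuna_2024_4_Phi-3-mini-4k-instruct \
        --pretrain_mm_mlp_adapter ./checkpoints/mm_projector.bin \
        --pc_ckpt_path ./pretrained_weight/Uni3D_PC_encoder/modelzoo/uni3d-small/model.pt \
        --pc_encoder_type small \
        --get_pc_tokens_way <way> \
        --task_type classification \
        --prompt_index 0
"""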
import os
import argparse
import torch
import json
from torch.utils.data import DataLoader
from tqdm import tqdm
from pointllm.data import ObjectPointCloudDataset
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
from llava.conversation import conv_templates
from llava.model.builder import load_pretrained_model
from llava.mm_utils import tokenizer_image_token, get_model_name_from_path

# Prompts 0 and 1 target classification, prompt 2 targets captioning
# (see the prompt_index check in __main__).
PROMPT_LISTS = [
    "What is this?",
    "This is an object of ",
    "Caption this 3D model in detail.",
]
class ModelArgs:
    """Bundles the model/encoder hyperparameters expected by initialize_other_modules."""
    def __init__(self, arg):
        self.vision_tower = None
        self.pretrain_mm_mlp_adapter = arg.pretrain_mm_mlp_adapter
        self.encoder_type = 'pc_encoder'  # one of: 'text_encoder', 'pc_encoder'
        self.std = arg.std
        self.pc_encoder_type = arg.pc_encoder_type
        self.pc_feat_dim = 192  # differs between pc encoders
        self.embed_dim = 1024
        self.group_size = 64
        self.num_group = 512
        self.pc_encoder_dim = 512
        self.patch_dropout = 0.0
        self.pc_ckpt_path = arg.pc_ckpt_path
        self.lora_path = arg.lora_path
        self.model_path = arg.model_path
        self.get_pc_tokens_way = arg.get_pc_tokens_way
def init_model(model_arg_):
    # The fixed directory name below is only used to derive a model name that
    # selects the right loading path inside load_pretrained_model; the actual
    # weights come from --model_path.
    model_name = get_model_name_from_path("llava-vicuna_phi_3_finetune_weight")
    model_path = model_arg_.model_path
    tokenizer, model, context_len = load_pretrained_model(model_path, None, model_name)
    if model_arg_.lora_path:
        from peft import PeftModel
        model = PeftModel.from_pretrained(model, model_arg_.lora_path)
        print("LoRA weights loaded.")
    # Attach the point-cloud encoder and the pretrained MLP adapter.
    model.get_model().initialize_other_modules(model_arg_)
    print("Point-cloud encoder and MLP adapter loaded.")
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
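    # Run the language model in bf16 but keep the point-cloud encoder in fp32.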
model.to(dtype=torch.bfloat16)
model.get_model().vision_tower.to(dtype=torch.float)
model.to(device)
return tokenizer, model
def load_dataset(data_path, anno_path, pointnum, conversation_types, use_color):
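    """Build the Objaverse point-cloud dataset (points only; no tokenizer needed)."""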
print("Loading validation datasets.")
dataset = ObjectPointCloudDataset(
data_path=data_path,
anno_path=anno_path,
pointnum=pointnum,
conversation_types=conversation_types,
use_color=use_color,
tokenizer=None # * load point cloud only
)
print("Done!")
return dataset
def get_dataloader(dataset, batch_size, shuffle=False, num_workers=4):
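    """Wrap the dataset in a DataLoader (no shuffling by default, for evaluation)."""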
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers)
return dataloader
def start_generation(model, dataloader, annos, prompt_index, output_dir, output_file, tokenizer, args):
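    """Generate an answer for every object in the dataloader and save all results to JSON."""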
qs = PROMPT_LISTS[prompt_index]
results = {"prompt": qs}
    # Prepend the image token and wrap the question in the Phi-3 chat template.
    qs = DEFAULT_IMAGE_TOKEN + "\n" + qs
    conv_mode = 'phi3_instruct'
    conv = conv_templates[conv_mode].copy()
    conv.append_message(conv.roles[0], qs)
    conv.append_message(conv.roles[1], None)
    qs = conv.get_prompt()
    print("prompt:", qs)
input_ids = (
tokenizer_image_token(qs, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
.unsqueeze(0)
.cuda()
)
responses = []
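    # Generate one answer per object, batch by batch.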
for batch in tqdm(dataloader):
        point_clouds = batch["point_clouds"].cuda()
        object_ids = batch["object_ids"]  # * list of strings
        # Repeat the single prompt across the batch dimension.
        texts = input_ids.repeat(point_clouds.size()[0], 1)
        images_tensor = point_clouds.to(dtype=torch.bfloat16)  # e.g. torch.Size([20, 8192, 6])
        temperature = args.temperature
        top_p = args.top_p
        max_new_tokens = args.max_new_tokens
        min_new_tokens = args.min_new_tokens
        num_beams = args.num_beams
        repetition_penalty = args.repetition_penalty
with torch.inference_mode():
output_ids = model.generate(
texts,
images=images_tensor,
                do_sample=temperature > 0 and num_beams == 1,
temperature=temperature,
top_p=top_p,
num_beams=num_beams,
max_new_tokens=max_new_tokens,
min_new_tokens=min_new_tokens,
use_cache=True,
repetition_penalty=repetition_penalty,
)
answers = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
        # Strip whitespace and the Phi-3 end-of-turn marker from each answer.
        outputs = []
        for answer in answers:
            answer = answer.replace("<|end|>", "").strip()
            outputs.append(answer)
# saving results
for obj_id, output in zip(object_ids, outputs):
responses.append({
"object_id": obj_id,
"ground_truth": annos[obj_id],
"model_output": output
})
results["results"] = responses
os.makedirs(output_dir, exist_ok=True)
# save the results to a JSON file
with open(os.path.join(output_dir, output_file), 'w') as fp:
json.dump(results, fp, indent=2)
# * print info
print(f"Saved results to {os.path.join(output_dir, output_file)}")
return results
def main(args):
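    """Run inference if no cached result file exists; otherwise load the cached JSON."""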
    # * output directory
args.output_dir = os.path.join(args.out_path, "evaluation")
# * output file
anno_file = os.path.splitext(os.path.basename(args.anno_path))[0]
args.output_file = f"{anno_file}_Objaverse_{args.task_type}_prompt{args.prompt_index}.json"
args.output_file_path = os.path.join(args.output_dir, args.output_file)
    # * Run inference first, then evaluate.
    if not os.path.exists(args.output_file_path):
        # * inference needed
# * load annotation files
with open(args.anno_path, 'r') as fp:
annos = json.load(fp)
dataset = load_dataset(args.data_path, args.anno_path, args.pointnum, ("simple_description",), args.use_color)
dataloader = get_dataloader(dataset, args.batch_size, args.shuffle, args.num_workers)
        model_arg = ModelArgs(args)
tokenizer, model = init_model(model_arg)
model.eval()
        # * convert annos from a list of {"object_id": ..., "conversations": ...} dicts
        # * to a mapping: object_id -> ground-truth answer
annos = {anno["object_id"]: anno["conversations"][1]['value'] for anno in annos}
print(f'[INFO] Start generating results for {args.output_file}.')
results = start_generation(model, dataloader, annos, args.prompt_index, args.output_dir, args.output_file, tokenizer, args)
        # * release the model and free CUDA memory
del model
torch.cuda.empty_cache()
else:
# * directly load the results
print(f'[INFO] {args.output_file_path} already exists, directly loading...')
with open(args.output_file_path, 'r') as fp:
results = json.load(fp)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--out_path", type=str, default="./output_json")
parser.add_argument("--pretrain_mm_mlp_adapter", type=str, required=True)
parser.add_argument("--lora_path", type=str, default=None)
parser.add_argument("--model_path", type=str, default='./lava-vicuna_2024_4_Phi-3-mini-4k-instruct')
parser.add_argument("--std", type=float, default=0.0)
parser.add_argument("--pc_ckpt_path", type=str, required=True, default="./pretrained_weight/Uni3D_PC_encoder/modelzoo/uni3d-small/model.pt")
parser.add_argument("--pc_encoder_type", type=str, required=True, default='small')
parser.add_argument("--get_pc_tokens_way", type=str, required=True)
# * dataset type
parser.add_argument("--data_path", type=str, default="./dataset/Objaverse/8192_npy", required=False)
parser.add_argument("--anno_path", type=str,
default="./dataset/Objaverse/PointLLM_brief_description_val_200_GT.json",
required=False)
parser.add_argument("--pointnum", type=int, default=8192)
parser.add_argument("--use_color", action="store_true", default=True)
# * data loader, batch_size, shuffle, num_workers
parser.add_argument("--batch_size", type=int, default=10)
parser.add_argument("--shuffle", type=bool, default=False)
parser.add_argument("--num_workers", type=int, default=10)
# * evaluation setting
parser.add_argument("--prompt_index", type=int, default=0)
parser.add_argument("--task_type", type=str, default="classification", choices=["captioning", "classification"],
help="Type of the task to evaluate.")
    # * generation settings (newly added)
    parser.add_argument("--max_new_tokens", type=int, default=150, help="max number of generated tokens")
    parser.add_argument("--min_new_tokens", type=int, default=0, help="min number of generated tokens")
    parser.add_argument("--num_beams", type=int, default=1)
    parser.add_argument("--temperature", type=float, default=0.1)
    parser.add_argument("--top_k", type=int, default=1)  # currently unused: not forwarded to model.generate()
    parser.add_argument("--top_p", type=float, default=0.7)
    parser.add_argument("--repetition_penalty", type=float, default=1.0)
args = parser.parse_args()
    # * check prompt index:
    # * classification expects prompt 0 or 1; captioning expects prompt 2. Warn otherwise.
    if args.task_type == "classification":
        if args.prompt_index not in (0, 1):
            print("[Warning] For classification task, prompt_index should be 0 or 1.")
    elif args.task_type == "captioning":
        if args.prompt_index != 2:
            print("[Warning] For captioning task, prompt_index should be 2.")
    else:
        raise NotImplementedError
main(args)