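"""Generate Objaverse captioning / classification results with a Phi-3-based
LLaVA-style model and a Uni3D point-cloud encoder.

The script loads point clouds with ObjectPointCloudDataset, prompts the model
with one of PROMPT_LISTS, and saves the generations alongside the ground
truth to a JSON file under <out_path>/evaluation.
"""
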
import argparse
import json
import os

import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
from llava.conversation import conv_templates
from llava.mm_utils import tokenizer_image_token, get_model_name_from_path
from llava.model.builder import load_pretrained_model

from pointllm.data import ObjectPointCloudDataset
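
# Prompt indices 0 and 1 are used for the classification task and index 2 for
# the captioning task (see the task_type check under __main__).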
PROMPT_LISTS = [
    "What is this?",
    "This is an object of ",
    "Caption this 3D model in detail.",
]

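
# The encoder hyperparameters hard-coded in ModelArgs (pc_feat_dim=192,
# embed_dim=1024, group_size=64, num_group=512, pc_encoder_dim=512) are
# assumed to match the small Uni3D checkpoint named in the default
# --pc_ckpt_path; change them together with the checkpoint.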
class ModelArgs:
    """Bundles the point-cloud encoder and adapter settings that
    initialize_other_modules() consumes."""

    def __init__(self, arg):
        self.vision_tower = None
        self.pretrain_mm_mlp_adapter = arg.pretrain_mm_mlp_adapter

        self.encoder_type = 'pc_encoder'
        self.std = arg.std

        self.pc_encoder_type = arg.pc_encoder_type
        self.pc_feat_dim = 192
        self.embed_dim = 1024
        self.group_size = 64
        self.num_group = 512
        self.pc_encoder_dim = 512
        self.patch_dropout = 0.0
        self.pc_ckpt_path = arg.pc_ckpt_path
        self.lora_path = arg.lora_path
        self.model_path = arg.model_path
        self.get_pc_tokens_way = arg.get_pc_tokens_way

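
# load_pretrained_model dispatches on the model *name*, so the name is derived
# from the fine-tuned weight directory string while the weights themselves are
# loaded from --model_path.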
def init_model(model_arg_):
    model_name = get_model_name_from_path("llava-vicuna_phi_3_finetune_weight")
    model_path = model_arg_.model_path
    # llava's load_pretrained_model returns (tokenizer, model, image_processor,
    # context_len); the image processor is not needed for point clouds.
    tokenizer, model, _, context_len = load_pretrained_model(model_path, None, model_name)

    if model_arg_.lora_path:
        # Imported lazily so peft is only required when LoRA weights are used.
        from peft import PeftModel

        model = PeftModel.from_pretrained(model, model_arg_.lora_path)
        print("Loaded LoRA weights.")

    model.get_model().initialize_other_modules(model_arg_)
    print("Loaded point-cloud encoder and MLP adapter.")
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Run the language model in bfloat16 but keep the point-cloud encoder in
    # float32 to match its pretrained weights.
    model.to(dtype=torch.bfloat16)
    model.get_model().vision_tower.to(dtype=torch.float)
    model.to(device)

    return tokenizer, model


def load_dataset(data_path, anno_path, pointnum, conversation_types, use_color):
    print("Loading validation datasets.")
    dataset = ObjectPointCloudDataset(
        data_path=data_path,
        anno_path=anno_path,
        pointnum=pointnum,
        conversation_types=conversation_types,
        use_color=use_color,
        tokenizer=None,  # Prompts are tokenized later, in start_generation.
    )
    print("Done!")
    return dataset


def get_dataloader(dataset, batch_size, shuffle=False, num_workers=4):
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers)

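
# start_generation builds one chat prompt, repeats it across each batch of
# point clouds, decodes the generations, and writes
# {object_id, ground_truth, model_output} records to a JSON file.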
def start_generation(model, dataloader, annos, prompt_index, output_dir, output_file, tokenizer, args):
    qs = PROMPT_LISTS[prompt_index]

    results = {"prompt": qs}

    # Prepend the image placeholder token; the model splices the point-cloud
    # tokens in at this position.
    qs = DEFAULT_IMAGE_TOKEN + "\n" + qs

    conv_mode = 'phi3_instruct'
    conv = conv_templates[conv_mode].copy()
    conv.append_message(conv.roles[0], qs)
    conv.append_message(conv.roles[1], None)
    qs = conv.get_prompt()

    print("Prompt:", qs)

    input_ids = (
        tokenizer_image_token(qs, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
        .unsqueeze(0)
        .cuda()
    )

    # Generation settings are constant across batches.
    temperature = args.temperature
    top_p = args.top_p
    max_new_tokens = args.max_new_tokens
    min_new_tokens = args.min_new_tokens
    num_beams = args.num_beams
    repetition_penalty = args.repetition_penalty

    responses = []

    for batch in tqdm(dataloader):
        point_clouds = batch["point_clouds"].cuda()
        object_ids = batch["object_ids"]

        # Every object in the batch shares the same prompt.
        texts = input_ids.repeat(point_clouds.size(0), 1)

        images_tensor = point_clouds.to(dtype=torch.bfloat16)

        with torch.inference_mode():
            output_ids = model.generate(
                texts,
                images=images_tensor,
                # Sample only when the temperature is positive and beam search
                # is disabled; otherwise decode greedily.
                do_sample=temperature > 0 and num_beams == 1,
                temperature=temperature,
                top_p=top_p,
                num_beams=num_beams,
                max_new_tokens=max_new_tokens,
                min_new_tokens=min_new_tokens,
                use_cache=True,
                repetition_penalty=repetition_penalty,
            )

        answers = tokenizer.batch_decode(output_ids, skip_special_tokens=True)

        outputs = []
        for answer in answers:
            # Strip the Phi-3 end-of-turn marker along with surrounding whitespace.
            answer = answer.strip().replace("<|end|>", "").strip()
            outputs.append(answer)

        for obj_id, output in zip(object_ids, outputs):
            responses.append({
                "object_id": obj_id,
                "ground_truth": annos[obj_id],
                "model_output": output
            })

    results["results"] = responses

    os.makedirs(output_dir, exist_ok=True)

    with open(os.path.join(output_dir, output_file), 'w') as fp:
        json.dump(results, fp, indent=2)

    print(f"Saved results to {os.path.join(output_dir, output_file)}")

    return results

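
# The GT annotation file is expected to be a list of entries of the form
# {"object_id": ..., "conversations": [<question>, {"value": <ground truth>}]},
# matching the dict comprehension in main() below.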
def main(args):
    args.output_dir = os.path.join(args.out_path, "evaluation")

    anno_file = os.path.splitext(os.path.basename(args.anno_path))[0]
    args.output_file = f"{anno_file}_Objaverse_{args.task_type}_prompt{args.prompt_index}.json"
    args.output_file_path = os.path.join(args.output_dir, args.output_file)

    if not os.path.exists(args.output_file_path):
        with open(args.anno_path, 'r') as fp:
            annos = json.load(fp)

        dataset = load_dataset(args.data_path, args.anno_path, args.pointnum, ("simple_description",), args.use_color)
        dataloader = get_dataloader(dataset, args.batch_size, args.shuffle, args.num_workers)

        model_arg = ModelArgs(args)
        tokenizer, model = init_model(model_arg)
        model.eval()

        # Map each object id to its ground-truth description.
        annos = {anno["object_id"]: anno["conversations"][1]['value'] for anno in annos}

        print(f'[INFO] Start generating results for {args.output_file}.')
        results = start_generation(model, dataloader, annos, args.prompt_index, args.output_dir, args.output_file, tokenizer, args)

        # Free GPU memory once generation is done.
        del model
        torch.cuda.empty_cache()
    else:
        print(f'[INFO] {args.output_file_path} already exists, directly loading...')
        with open(args.output_file_path, 'r') as fp:
            results = json.load(fp)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--out_path", type=str, default="./output_json")
    parser.add_argument("--pretrain_mm_mlp_adapter", type=str, required=True)

    parser.add_argument("--lora_path", type=str, default=None)
    parser.add_argument("--model_path", type=str, default='./lava-vicuna_2024_4_Phi-3-mini-4k-instruct')

    parser.add_argument("--std", type=float, default=0.0)
    # argparse ignores defaults on required arguments, so the values below only
    # document typical settings.
    parser.add_argument("--pc_ckpt_path", type=str, required=True,
                        default="./pretrained_weight/Uni3D_PC_encoder/modelzoo/uni3d-small/model.pt")
    parser.add_argument("--pc_encoder_type", type=str, required=True, default='small')
    parser.add_argument("--get_pc_tokens_way", type=str, required=True)

    parser.add_argument("--data_path", type=str, default="./dataset/Objaverse/8192_npy")
    parser.add_argument("--anno_path", type=str,
                        default="./dataset/Objaverse/PointLLM_brief_description_val_200_GT.json")
    parser.add_argument("--pointnum", type=int, default=8192)
    # Note: with action="store_true" and default=True, use_color is always True.
    parser.add_argument("--use_color", action="store_true", default=True)

    parser.add_argument("--batch_size", type=int, default=10)
    # type=bool would parse any non-empty string (even "False") as True, so
    # shuffling is exposed as a flag instead.
    parser.add_argument("--shuffle", action="store_true")
    parser.add_argument("--num_workers", type=int, default=10)

    parser.add_argument("--prompt_index", type=int, default=0)

    parser.add_argument("--task_type", type=str, default="classification",
                        choices=["captioning", "classification"],
                        help="Type of the task to evaluate.")

    parser.add_argument("--max_new_tokens", type=int, default=150, help="max number of generated tokens")
    parser.add_argument("--min_new_tokens", type=int, default=0, help="min number of generated tokens")
    parser.add_argument("--num_beams", type=int, default=1)
    parser.add_argument("--temperature", type=float, default=0.1)
    parser.add_argument("--top_k", type=int, default=1)  # Unused: generate() only receives top_p.
    parser.add_argument("--top_p", type=float, default=0.7)
    parser.add_argument("--repetition_penalty", type=float, default=1.0)

    args = parser.parse_args()

    # Sanity-check that the chosen prompt matches the task.
    if args.task_type == "classification":
        if args.prompt_index not in (0, 1):
            print("[Warning] For classification task, prompt_index should be 0 or 1.")
    elif args.task_type == "captioning":
        if args.prompt_index != 2:
            print("[Warning] For captioning task, prompt_index should be 2.")
    else:
        raise NotImplementedError

    main(args)
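
# Example invocation (script name and paths are illustrative):
#   python evaluate_objaverse.py \
#       --pretrain_mm_mlp_adapter ./checkpoints/mm_projector.bin \
#       --pc_ckpt_path ./pretrained_weight/Uni3D_PC_encoder/modelzoo/uni3d-small/model.pt \
#       --pc_encoder_type small \
#       --get_pc_tokens_way <way> \
#       --task_type captioning --prompt_index 2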