"""Training / evaluation entry point for the LISA model using DeepSpeed."""

import argparse
import os
import shutil
import sys
import time
from functools import partial

import deepspeed
import numpy as np
import torch
import tqdm
import transformers
from torch.utils.tensorboard import SummaryWriter

from model.LISA import LISA
from utils.dataset import HybridDataset, ValDataset, collate_fn
from utils.utils import (
    AverageMeter,
    ProgressMeter,
    Summary,
    dict_to_cuda,
    intersectionAndUnionGPU,
)


def parse_args(args):
    """Parse command-line arguments for LISA training.

    Args:
        args: List of argument strings (typically ``sys.argv[1:]``).

    Returns:
        The populated ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(description="LISA Model Training")
    parser.add_argument("--local_rank", default=0, type=int, help="node rank")
    parser.add_argument(
        "--version", default="liuhaotian/llava-llama-2-13b-chat-lightning-preview"
    )
    parser.add_argument("--vis_save_path", default="./vis_output", type=str)
    parser.add_argument(
        "--precision",
        default="bf16",
        type=str,
        choices=["fp32", "bf16", "fp16"],
        help="precision for inference",
    )
    parser.add_argument("--image_size", default=1024, type=int, help="image size")
    parser.add_argument("--model_max_length", default=512, type=int)
    parser.add_argument("--lora_r", default=8, type=int)
    # NOTE: argparse maps the hyphenated flag to ``args.vision_tower``.
    parser.add_argument(
        "--vision-tower", default="openai/clip-vit-large-patch14", type=str
    )
    parser.add_argument("--load_in_8bit", action="store_true", default=False)
    parser.add_argument("--load_in_4bit", action="store_true", default=False)

    parser.add_argument(
        "--dataset", default="sem_seg||refer_seg||vqa||reason_seg", type=str
    )
    parser.add_argument(
        "--sem_seg_data",
        default="ade20k||cocostuff||pascal_part||paco_lvis||mapillary",
        type=str,
    )
    parser.add_argument(
        "--refer_seg_data", default="refclef||refcoco||refcoco+||refcocog", type=str
    )
    parser.add_argument("--vqa_data", default="llava_instruct_150k", type=str)
    parser.add_argument("--reason_seg_data", default="ReasonSeg|train", type=str)
    parser.add_argument("--val_dataset", default="ReasonSeg|val", type=str)
    parser.add_argument("--dataset_dir", default="./dataset", type=str)
    parser.add_argument("--log_base_dir", default="./runs", type=str)
    parser.add_argument("--exp_name", default="lisa", type=str)
    parser.add_argument("--epochs", default=10, type=int)
    parser.add_argument("--steps_per_epoch", default=500, type=int)
    parser.add_argument(
        "--batch_size", default=2, type=int, help="batch size per device per step"
    )
    parser.add_argument(
        "--grad_accumulation_steps",
        default=10,
        type=int,
    )
    parser.add_argument("--val_batch_size", default=1, type=int)
    parser.add_argument("--workers", default=4, type=int)
    parser.add_argument("--lr", default=0.0003, type=float)
    parser.add_argument("--ce_loss_weight", default=1.0, type=float)
    parser.add_argument("--dice_loss_weight", default=0.5, type=float)
    parser.add_argument("--bce_loss_weight", default=2.0, type=float)
    parser.add_argument("--lora_alpha", default=16, type=int)
    parser.add_argument("--lora_dropout", default=0.05, type=float)
    parser.add_argument("--lora_target_modules", default="q_proj,v_proj", type=str)
    parser.add_argument("--explanatory", default=0.1, type=float)
    parser.add_argument("--beta1", default=0.9, type=float)
    parser.add_argument("--beta2", default=0.95, type=float)
    parser.add_argument("--num_classes_per_sample", default=3, type=int)
    parser.add_argument("--exclude_val", action="store_true", default=False)
    parser.add_argument("--no_eval", action="store_true", default=False)
    parser.add_argument("--eval_only", action="store_true", default=False)
    parser.add_argument("--vision_pretrained", default="PATH_TO_SAM_ViT-H", type=str)
    parser.add_argument("--weight", default="", type=str)
    parser.add_argument("--print_freq", default=1, type=int)
    parser.add_argument("--start_epoch", default=0, type=int)
    return parser.parse_args(args)


def _cast_images_to_precision(input_dict, precision):
    """Cast the ``images`` and ``images_clip`` entries to the run precision.

    Args:
        input_dict: Batch dictionary holding ``"images"`` and
            ``"images_clip"`` tensors; it is modified in place.
        precision: One of ``"fp16"``, ``"bf16"``; any other value selects
            float32.

    Returns:
        The same ``input_dict`` object, for convenience.
    """
    for key in ("images", "images_clip"):
        if precision == "fp16":
            input_dict[key] = input_dict[key].half()
        elif precision == "bf16":
            input_dict[key] = input_dict[key].bfloat16()
        else:
            input_dict[key] = input_dict[key].float()
    return input_dict


def main(args):
    """Build the model and datasets, then train and/or validate.

    Args:
        args: List of command-line argument strings (see ``parse_args``).

    Raises:
        ValueError: If ``--eval_only`` is combined with ``--no_eval``; there
            would be no validation set to evaluate on.
    """
    args = parse_args(args)
    if args.eval_only and args.no_eval:
        # Fail fast: previously this surfaced as a NameError on `val_loader`
        # only after the model had been fully loaded.
        raise ValueError("--eval_only cannot be combined with --no_eval")

    args.log_dir = os.path.join(args.log_base_dir, args.exp_name)
    # Only rank 0 owns the log directory and the TensorBoard writer.
    if args.local_rank == 0:
        os.makedirs(args.log_dir, exist_ok=True)
        writer = SummaryWriter(args.log_dir)
    else:
        writer = None

    # Create model
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        args.version,
        cache_dir=None,
        model_max_length=args.model_max_length,
        padding_side="right",
        use_fast=False,
    )
    tokenizer.pad_token = tokenizer.unk_token
    # Register the segmentation token and remember its vocabulary id.
    tokenizer.add_tokens("[SEG]")
    ret_token_idx = tokenizer("[SEG]", add_special_tokens=False).input_ids
    args.seg_token_idx = ret_token_idx[0]

    model = LISA(
        args.local_rank,
        args.seg_token_idx,
        tokenizer,
        args.version,
        args.lora_r,
        args.precision,
        vision_tower=args.vision_tower,
        load_in_8bit=args.load_in_8bit,
        load_in_4bit=args.load_in_4bit,
        ce_loss_weight=args.ce_loss_weight,
        dice_loss_weight=args.dice_loss_weight,
        bce_loss_weight=args.bce_loss_weight,
        vision_pretrained=args.vision_pretrained,
    )

    if args.weight:
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from trusted sources.
        state_dict = torch.load(args.weight, map_location="cpu")
        model.load_state_dict(state_dict, strict=True)

    world_size = torch.cuda.device_count()
    args.distributed = world_size > 1

    train_dataset = HybridDataset(
        args.dataset_dir,
        tokenizer,
        args.vision_tower,
        # One "epoch" is a fixed number of optimizer steps across all devices.
        samples_per_epoch=args.batch_size
        * args.grad_accumulation_steps
        * args.steps_per_epoch
        * world_size,
        precision=args.precision,
        image_size=args.image_size,
        num_classes_per_sample=args.num_classes_per_sample,
        exclude_val=args.exclude_val,
        dataset=args.dataset,
        sem_seg_data=args.sem_seg_data,
        refer_seg_data=args.refer_seg_data,
        vqa_data=args.vqa_data,
        reason_seg_data=args.reason_seg_data,
        explanatory=args.explanatory,
    )

    if not args.no_eval:
        val_dataset = ValDataset(
            args.dataset_dir,
            tokenizer,
            args.vision_tower,
            args.val_dataset,
            args.image_size,
        )
        print(
            f"Training with {len(train_dataset)} examples and validating with {len(val_dataset)} examples."
        )
    else:
        val_dataset = None
        print(f"Training with {len(train_dataset)} examples.")

    ds_config = {
        "train_micro_batch_size_per_gpu": args.batch_size,
        "gradient_accumulation_steps": args.grad_accumulation_steps,
        "optimizer": {
            "type": "AdamW",
            "params": {
                "lr": args.lr,
                "weight_decay": 0.0,
                "betas": (args.beta1, args.beta2),
            },
        },
        "scheduler": {
            "type": "WarmupDecayLR",
            "params": {
                "total_num_steps": args.epochs * args.steps_per_epoch,
                "warmup_min_lr": 0,
                "warmup_max_lr": args.lr,
                "warmup_num_steps": 100,
                "warmup_type": "linear",
            },
        },
        "fp16": {
            "enabled": args.precision == "fp16",
        },
        "bf16": {
            "enabled": args.precision == "bf16",
        },
        "gradient_clipping": 1.0,
        "zero_optimization": {
            "stage": 2,
            "contiguous_gradients": True,
            "overlap_comm": True,
            "reduce_scatter": True,
            "reduce_bucket_size": 5e8,
            "allgather_bucket_size": 5e8,
        },
    }
    model_engine, optimizer, train_loader, scheduler = deepspeed.initialize(
        model=model,
        model_parameters=model.parameters(),
        training_data=train_dataset,
        collate_fn=partial(collate_fn, tokenizer=tokenizer),
        config=ds_config,
    )

    val_loader = None
    if val_dataset is not None:
        # validate() indexes element 0 of the predictions, so it only
        # supports one sample per batch.
        assert args.val_batch_size == 1
        val_sampler = torch.utils.data.distributed.DistributedSampler(
            val_dataset, shuffle=False, drop_last=False
        )
        val_loader = torch.utils.data.DataLoader(
            val_dataset,
            batch_size=args.val_batch_size,
            shuffle=False,
            num_workers=args.workers,
            pin_memory=False,
            sampler=val_sampler,
            collate_fn=partial(collate_fn, tokenizer=tokenizer),
        )

    train_iter = iter(train_loader)
    best_score, cur_ciou = 0.0, 0.0

    if args.eval_only:
        giou, ciou = validate(val_loader, model_engine, 0, writer, args)
        sys.exit()

    for epoch in range(args.start_epoch, args.epochs):
        # train for one epoch
        train_iter = train(
            train_loader,
            model_engine,
            epoch,
            scheduler,
            writer,
            train_iter,
            args,
        )

        is_best = False
        if not args.no_eval:
            giou, ciou = validate(val_loader, model_engine, epoch, writer, args)
            is_best = giou > best_score
            best_score = max(giou, best_score)
            cur_ciou = ciou if is_best else cur_ciou

        # Without evaluation every epoch is checkpointed; otherwise only
        # epochs that improve gIoU are.
        if args.no_eval or is_best:
            save_dir = os.path.join(args.log_dir, "ckpt_model")
            if args.local_rank == 0:
                torch.save(
                    {"epoch": epoch},
                    os.path.join(
                        args.log_dir,
                        "meta_log_giou{:.3f}_ciou{:.3f}.pth".format(
                            best_score, cur_ciou
                        ),
                    ),
                )
                if os.path.exists(save_dir):
                    shutil.rmtree(save_dir)
            # Wait for rank 0 to finish removing the old checkpoint before
            # any rank starts writing the new one.
            torch.distributed.barrier()
            model_engine.save_checkpoint(save_dir)


def train(
    train_loader,
    model,
    epoch,
    scheduler,
    writer,
    train_iter,
    args,
):
    """Main training loop.

    Runs ``args.steps_per_epoch`` optimizer steps, each made of
    ``args.grad_accumulation_steps`` micro-batches.

    Args:
        train_loader: Data loader used to restart iteration when exhausted.
        model: Engine exposing ``train()``, ``backward()`` and ``step()``.
        epoch: Index of the current epoch (used for logging).
        scheduler: LR scheduler queried through ``get_last_lr()``.
        writer: TensorBoard writer; only used when ``args.local_rank == 0``.
        train_iter: Iterator over ``train_loader`` carried across epochs.
        args: Parsed command-line namespace.

    Returns:
        The (possibly re-created) training iterator, so the caller can pass
        it to the next epoch.
    """
    batch_time = AverageMeter("Time", ":6.3f")
    data_time = AverageMeter("Data", ":6.3f")
    losses = AverageMeter("Loss", ":.4f")
    ce_losses = AverageMeter("CeLoss", ":.4f")
    mask_bce_losses = AverageMeter("MaskBCELoss", ":.4f")
    mask_dice_losses = AverageMeter("MaskDICELoss", ":.4f")
    mask_losses = AverageMeter("MaskLoss", ":.4f")

    progress = ProgressMeter(
        args.steps_per_epoch,
        [
            batch_time,
            losses,
            ce_losses,
            mask_losses,
            mask_bce_losses,
            mask_dice_losses,
        ],
        prefix="Epoch: [{}]".format(epoch),
    )

    # switch to train mode
    model.train()
    end = time.time()
    for global_step in range(args.steps_per_epoch):
        for i in range(args.grad_accumulation_steps):
            try:
                input_dict = next(train_iter)
            except StopIteration:
                # Loader exhausted: restart it. Any other error (bad sample,
                # worker crash, Ctrl-C) must propagate instead of being
                # silently retried.
                train_iter = iter(train_loader)
                input_dict = next(train_iter)

            data_time.update(time.time() - end)
            input_dict = dict_to_cuda(input_dict)
            input_dict = _cast_images_to_precision(input_dict, args.precision)

            output_dict = model(**input_dict)

            loss = output_dict["loss"]
            ce_loss = output_dict["ce_loss"]
            mask_bce_loss = output_dict["mask_bce_loss"]
            mask_dice_loss = output_dict["mask_dice_loss"]
            mask_loss = output_dict["mask_loss"]

            batch_size = input_dict["images"].size(0)
            losses.update(loss.item(), batch_size)
            ce_losses.update(ce_loss.item(), batch_size)
            mask_bce_losses.update(mask_bce_loss.item(), batch_size)
            mask_dice_losses.update(mask_dice_loss.item(), batch_size)
            mask_losses.update(mask_loss.item(), batch_size)
            model.backward(loss)
            model.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        # Step index that keeps increasing across epochs, so TensorBoard
        # curves from different epochs do not overwrite each other.
        tb_step = epoch * args.steps_per_epoch + global_step

        if global_step % args.print_freq == 0:
            if args.distributed:
                batch_time.all_reduce()
                data_time.all_reduce()

                losses.all_reduce()
                ce_losses.all_reduce()
                mask_bce_losses.all_reduce()
                mask_dice_losses.all_reduce()
                mask_losses.all_reduce()

            if args.local_rank == 0:
                progress.display(global_step + 1)
                writer.add_scalar("train/loss", losses.avg, tb_step)
                writer.add_scalar("train/ce_loss", ce_losses.avg, tb_step)
                writer.add_scalar(
                    "train/mask_bce_loss", mask_bce_losses.avg, tb_step
                )
                writer.add_scalar(
                    "train/mask_dice_loss", mask_dice_losses.avg, tb_step
                )
                writer.add_scalar("train/mask_loss", mask_losses.avg, tb_step)
                writer.add_scalar(
                    "metrics/total_secs_per_batch", batch_time.avg, tb_step
                )
                writer.add_scalar(
                    "metrics/data_secs_per_batch", data_time.avg, tb_step
                )

            batch_time.reset()
            data_time.reset()
            losses.reset()
            ce_losses.reset()
            mask_bce_losses.reset()
            mask_dice_losses.reset()
            mask_losses.reset()

        if global_step != 0:
            curr_lr = scheduler.get_last_lr()
            if args.local_rank == 0:
                writer.add_scalar("train/lr", curr_lr[0], tb_step)

    return train_iter


def validate(val_loader, model_engine, epoch, writer, args):
    """Evaluate the model on the validation set and log gIoU / cIoU.

    gIoU is the mean of per-mask IoU values; cIoU is the accumulated
    intersection divided by the accumulated union. Both are read at class
    index 1 of the two-class statistics (presumably the foreground class —
    confirm against ``intersectionAndUnionGPU``).

    Args:
        val_loader: Validation data loader yielding one sample per batch.
        model_engine: Engine used for the forward pass.
        epoch: Epoch index used as the TensorBoard step.
        writer: TensorBoard writer; only used when ``args.local_rank == 0``.
        args: Parsed command-line namespace.

    Returns:
        Tuple ``(giou, ciou)``.
    """
    intersection_meter = AverageMeter("Intersec", ":6.3f", Summary.SUM)
    union_meter = AverageMeter("Union", ":6.3f", Summary.SUM)
    acc_iou_meter = AverageMeter("gIoU", ":6.3f", Summary.SUM)

    model_engine.eval()

    for input_dict in tqdm.tqdm(val_loader):
        input_dict = dict_to_cuda(input_dict)
        input_dict = _cast_images_to_precision(input_dict, args.precision)

        # Evaluation needs no gradients; skipping autograd bookkeeping
        # reduces memory use.
        with torch.no_grad():
            output_dict = model_engine(**input_dict)

        pred_masks = output_dict["pred_masks"]
        masks_list = output_dict["gt_masks"][0].int()
        output_list = (pred_masks[0] > 0).int()  # binarize the mask logits
        assert len(pred_masks) == 1

        intersection, union, acc_iou = 0.0, 0.0, 0.0
        for mask_i, output_i in zip(masks_list, output_list):
            intersection_i, union_i, _ = intersectionAndUnionGPU(
                output_i.contiguous().clone(), mask_i.contiguous(), 2, ignore_index=255
            )
            intersection += intersection_i
            union += union_i
            acc_iou += intersection_i / (union_i + 1e-5)
            acc_iou[union_i == 0] += 1.0  # no-object target
        intersection, union = intersection.cpu().numpy(), union.cpu().numpy()
        acc_iou = acc_iou.cpu().numpy() / masks_list.shape[0]
        intersection_meter.update(intersection)
        union_meter.update(union)
        acc_iou_meter.update(acc_iou, n=masks_list.shape[0])

    intersection_meter.all_reduce()
    union_meter.all_reduce()
    acc_iou_meter.all_reduce()

    iou_class = intersection_meter.sum / (union_meter.sum + 1e-10)
    ciou = iou_class[1]
    giou = acc_iou_meter.avg[1]

    if args.local_rank == 0:
        writer.add_scalar("val/giou", giou, epoch)
        # cIoU gets its own tag; it was previously written under "val/giou",
        # colliding with the gIoU series.
        writer.add_scalar("val/ciou", ciou, epoch)
        print("giou: {:.4f}, ciou: {:.4f}".format(giou, ciou))

    return giou, ciou


if __name__ == "__main__":
    main(sys.argv[1:])