# Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright:
# Adopted from tatsu-lab@stanford_alpaca. Below is the original copyright:
# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# import os
# os.environ["CUDA_VISIBLE_DEVICES"]="2"
from dataclasses import dataclass, field
import pathlib
from typing import Optional, List
import torch
import transformers
from pointllm.train.pointllm_trainer import PointLLMTrainer
from pointllm import conversation as conversation_lib
from pointllm.model import *
from pointllm.data import make_object_point_data_module
# * logger
from pointllm.utils import build_logger
IGNORE_INDEX = -100
DEFAULT_PAD_TOKEN = "[PAD]"
DEFAULT_EOS_TOKEN = "</s>"
DEFAULT_BOS_TOKEN = "<s>"
DEFAULT_UNK_TOKEN = "<unk>"
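# IGNORE_INDEX (-100) marks label positions that the cross-entropy loss skips; the special-token
# strings follow the Llama/SentencePiece convention inherited from the upstream Alpaca/FastChat code.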

@dataclass
class ModelArguments:
    # /home/pointllm_weight_2/PointLLM_7B_v1.2
    model_name_or_path: Optional[str] = field(default="/home/TinyGPT-V/pretrain_weight/phi-new")
    # model_name_or_path: Optional[str] = field(default="/home/pointllm_weight_2/PointLLM_7B_v1.2")
    version: Optional[str] = field(default="v1")

@dataclass
class DataArguments:
    data_path: str = field(default="/home/PointLLM/data/objaverse_data", metadata={"help": "Path to the training data."})
    anno_path: str = field(default='/home/PointLLM/data/anno_data/PointLLM_complex_instruction_70K.json', metadata={"help": "Path to the utterance data. If None, will use referit3d by default."})
    use_color: bool = field(default=True, metadata={"help": "Whether to use color."})
    data_debug_num: int = field(default=0, metadata={"help": "Number of data to use in debug mode. If larger than 0, use debug mode, else use the whole data."})
    split_train_val: bool = field(default=False, metadata={"help": "Whether to split train and val."})
    split_ratio: float = field(default=0.9, metadata={"help": "Ratio of train and val."})
    pointnum: int = field(default=8192, metadata={"help": "Number of points."})
    # conversation_types: List[str] = field(default_factory=lambda: ["simple_description"], metadata={"help": "Conversation types to use."})
    conversation_types: List[str] = field(default_factory=lambda: ["detailed_description", "single_round", "multi_round"],
                                          metadata={"help": "Conversation types to use."})
    is_multimodal: bool = True

@dataclass
class TrainingArguments(transformers.TrainingArguments):
    # * can refer to https://huggingface.co/docs/transformers/v4.28.1/en/main_classes/trainer#transformers.TrainingArguments
    cache_dir: Optional[str] = field(default='/home/PointLLM/trash')
    output_dir: Optional[str] = field(default='/home/PointLLM/trash')
    save_strategy: Optional[str] = field(default='no')
    save_steps: int = field(default=2400)
    optim: str = field(default="adamw_torch")
    dataloader_num_workers: int = field(default=24)
    model_max_length: int = field(
        default=2048,
        metadata={"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."},
    )
    per_device_train_batch_size: int = field(
        default=6, metadata={"help": "Batch size per GPU/TPU/MPS/NPU core/CPU for training."}
    )
    model_debug: bool = field(default=False, metadata={"help": "Whether to use a small model."})  # * whether to load checkpoints at the mo
    fix_llm: bool = field(default=True, metadata={"help": "Whether to fix the LLM."})
    fix_pointnet: bool = field(default=True, metadata={"help": "Whether to fix the PointNet."})
    remove_unused_columns: bool = field(default=False)
    force_fsdp: bool = field(default=False)
    bf16: bool = field(default=True)
    # * for two stage training
    tune_mm_mlp_adapter: bool = field(default=True)  # * set True when pre-training, and False when fine-tuning
    stage_2: bool = field(default=False)  # * set True when fine-tuning
    pretrained_mm_mlp_adapter: Optional[str] = field(default=None)  # * path to the pre-trained projector & output_embed & input_embed
    detatch_point_token: bool = field(default=False)  # * deprecated
    # * point backbone ckpt path
    # point_backbone_ckpt: str = field(default='/home/pointllm_weight_2/PointLLM_7B_v1.2')
    # point_backbone_ckpt: str = field(default="/home/R_Decoder/pretrain_weight/point_mae/pretrain.pth")
    point_backbone_ckpt: str = field(default="/home/pointllm_weight_2/point_model/point_model.pth")
    # point_backbone_ckpt: str = field(default="")


def safe_save_model_for_hf_trainer(trainer: transformers.Trainer,
                                   output_dir: str):
    """Collects the state dict and dump to disk."""
    state_dict = trainer.model.state_dict()
    if trainer.args.should_save:
        cpu_state_dict = {
            key: value.cpu()
            for key, value in state_dict.items()
        }
        del state_dict
        trainer._save(output_dir, state_dict=cpu_state_dict)  # noqa
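

# train() below: parse the three argument dataclasses, load the point-cloud LLM, freeze/unfreeze
# modules according to the stage flags, build the Objaverse point-text data module, and run the
# Hugging Face Trainer.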
def train():
    parser = transformers.HfArgumentParser(
        (ModelArguments, DataArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    training_args.log_level = "info"  # * default is passive(warning)
    # training_args.bf16 = True

    # * build logger
    training_args.output_dir = '/home/PointLLM/trash'  # note: hard-coded here, overriding any --output_dir passed on the command line
    logger = build_logger(__name__, training_args.output_dir + '/train.log')

    if training_args.model_debug:
        # * do not load checkpoint, load from config
        config = transformers.AutoConfig.from_pretrained(
            model_args.model_name_or_path,
            cache_dir=training_args.cache_dir,
            torch_dtype=torch.float32
        )
        model = PointLLMLlamaForCausalLM._from_config(config)
    else:
        model = PointLLMLlamaForCausalLM.from_pretrained(
            model_args.model_name_or_path,
            cache_dir=training_args.cache_dir,
            torch_dtype=torch.float16
        )
    model.config.use_cache = False

    if training_args.fix_llm:
        # * This will fix all the parameters
        logger.info("LLM is fixed. Fix_llm flag is set to True")
        # * fix llama, lm_head, pointnet, projection layer here
        model.requires_grad_(False)
        model.get_model().fix_llm = True
        model.get_model().point_proj.requires_grad_(True)
        model.get_model().point_backbone.requires_grad_(True)  # * set as True for fsdp, use fix_pointnet flag to control
    else:
        model.get_model().fix_llm = False
        logger.warning("LLM is trainable. Fix_llm flag is set to False")

    tokenizer = transformers.AutoTokenizer.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
        model_max_length=training_args.model_max_length,
        padding_side="right",
        use_fast=True,
    )
    # tokenizer = transformers.AutoTokenizer.from_pretrained(
    #     model_args.model_name_or_path,
    #     padding_side="right",
    #     use_fast=False,
    # )

    if model_args.version == "v0" or "v0" in model_args.model_name_or_path:
        raise ValueError("v0 is deprecated.")
    else:
        # tokenizer.pad_token = tokenizer.unk_token
        tokenizer.pad_token = tokenizer.eos_token
        conversation_lib.default_conversation = conversation_lib.conv_templates["vicuna_v1_1"]

    if not training_args.fix_pointnet:
        # * not fix pointnet
        logger.info("Point backbone is trainable. Fix_pointnet flag is set to False, pointnet grad will be recorded.")
        model.get_model().fix_pointnet = False
    else:
        logger.info("Point backbone is fixed. Fix_pointnet flag is set to True, pointnet grad will not be recorded.")
        model.get_model().fix_pointnet = True  # * use with torch.inference_mode to control, not requires_grad for fsdp for second stage
        if not training_args.stage_2:
            logger.info("Set requires_grad of point backbone to False")
            model.get_model().point_backbone.requires_grad_(False)  # * fix pointnet for first stage, need for fsdp in stage2

    if training_args.tune_mm_mlp_adapter:
        # * not fix the projection layer
        # * may need to set the embed_tokens to require_grad = True if added new tokens
        # * this is done in initialize_tokenizer_point_backbone_config
        logger.info("Point projection layer is trainable.")
    else:
        model.get_model().point_proj.requires_grad_(False)
        logger.info("Point projection layer is fixed.")

    if not training_args.stage_2:
        # * we assume in stage2, llm, point_backbone, and projection layer can be loaded from the model checkpoint
        print(f"Default point_backbone_ckpt is {training_args.point_backbone_ckpt}.")
        model.get_model().load_point_backbone_checkpoint(training_args.point_backbone_ckpt)
        model.initialize_tokenizer_point_backbone_config(tokenizer=tokenizer, device=training_args.device, fix_llm=training_args.fix_llm)
    else:
        # * stage2
        model.initialize_tokenizer_point_backbone_config_wo_embedding(tokenizer=tokenizer)

    point_backbone_config = model.get_model().point_backbone_config

    data_args.point_token_len = point_backbone_config['point_token_len']
    data_args.mm_use_point_start_end = point_backbone_config['mm_use_point_start_end']
    data_args.point_backbone_config = point_backbone_config

    params_no_grad = [n for n, p in model.named_parameters() if not p.requires_grad]
    if len(params_no_grad) > 0:
        if training_args.fsdp is not None and len(training_args.fsdp) > 0:
            if len(params_no_grad) < 10:
                print('[WARNING] Attempting to use FSDP while {} parameters do not require gradients: {}'.format(len(params_no_grad), params_no_grad))
            else:
                print('[WARNING] Attempting to use FSDP while {} parameters do not require gradients: {}...(omitted)'.format(len(params_no_grad), ', '.join(params_no_grad[:10])))
            print("[WARNING] Attempting to use FSDP with partially frozen parameters, this is experimental.")
            print("[WARNING] As of 4/30/23, this feature requires PyTorch-nightly build. See here for details: https://github.com/haotian-liu/LLaVA#experimental-use-fsdp-to-save-memory-in-pretraining")

            from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP

            def patch_FSDP_use_orig_params(func):
                def wrap_func(*args, **kwargs):
                    use_orig_params = kwargs.pop('use_orig_params', True)
                    return func(*args, **kwargs, use_orig_params=use_orig_params)
                return wrap_func

            FSDP.__init__ = patch_FSDP_use_orig_params(FSDP.__init__)
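
            # The patch above forces use_orig_params=True whenever FSDP wraps a module, so FSDP can
            # shard modules that mix frozen and trainable parameters instead of raising an error.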

    data_module = make_object_point_data_module(tokenizer=tokenizer,
                                                data_args=data_args)

    # for name, param in model.model.layers.named_parameters():
    #     param.requires_grad = False
    #
    # for name, param in model.model.point_backbone.named_parameters():
    #     param.requires_grad = False
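
    # The loops below freeze every parameter under model.model, then re-enable gradients only for
    # the point projection layer and the layer-norm weights (q/k layernorm, input/post layernorm,
    # and the final layernorm), which are also cast to FP32 further down.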
    for name, param in model.model.named_parameters():
        param.requires_grad = False

    for name, param in model.model.named_parameters():
        if 'point_proj' in name:
            param.requires_grad = True

    for name, param in model.model.named_parameters():
        if 'q_layernorm' in name:
            param.requires_grad = True
        if 'k_layernorm' in name:
            param.requires_grad = True
        if 'post_layernorm' in name:
            param.requires_grad = True
        if 'input_layernorm' in name:
            param.requires_grad = True
        # if 'input_layernorm' in name:
        #     param.requires_grad = True
        if 'final_layernorm' in name:
            param.requires_grad = True

    ############################
    #
    #
    # for name, param in model.model.layers.named_parameters():
    #     param.requires_grad = False
    #
    # # for i, layer in enumerate(llama_model.model.layers):
    # #     # If the layer index is less than 5, set that layer's parameters to trainable
    # #     if i < 5:
    # #         for param in layer.parameters():
    # #             param.requires_grad = True
    # #         # Cast these layers' parameters to FP32
    # #         layer.to(torch.float32)

    for i, layer in enumerate(model.model.layers):
        # layer.register_forward_hook(print_layer_output)
        # set the layer-norm weights to trainable and keep them in FP32
        layer.self_attn.q_layernorm.weight.requires_grad = True
        layer.self_attn.k_layernorm.weight.requires_grad = True
        layer.post_layernorm.weight.requires_grad = True
        layer.input_layernorm.weight.requires_grad = True
        layer.self_attn.q_layernorm.weight.data = layer.self_attn.q_layernorm.weight.data.float()
        layer.self_attn.k_layernorm.weight.data = layer.self_attn.k_layernorm.weight.data.float()
        layer.post_layernorm.weight.data = layer.post_layernorm.weight.data.float()
        layer.input_layernorm.weight.data = layer.input_layernorm.weight.data.float()
        # do the same for the bias terms
        if layer.self_attn.q_layernorm.bias is not None:
            layer.self_attn.q_layernorm.bias.data = layer.self_attn.q_layernorm.bias.data.float()
        if layer.self_attn.k_layernorm.bias is not None:
            layer.self_attn.k_layernorm.bias.data = layer.self_attn.k_layernorm.bias.data.float()
        if layer.input_layernorm.bias is not None:
            layer.input_layernorm.bias.data = layer.input_layernorm.bias.data.float()

    model.model.final_layernorm.weight.requires_grad = True
    model.model.final_layernorm.weight.data = model.model.final_layernorm.weight.data.float()
    if model.model.final_layernorm.bias is not None:
        model.model.final_layernorm.bias.data = model.model.final_layernorm.bias.data.float()
    ###################################

    for name, param in model.model.named_parameters():
        if param.requires_grad:  # if the parameter requires gradients, it will be updated
            logger.info(f"Parameter {name} will be updated.")

    # import os
    # torch.save({
    #     'base_model': model.model.point_backbone.state_dict(), }, os.path.join('/home/pointllm_weight_2/point_model/point_model.pth'))
    # model = model.half()

    trainer = PointLLMTrainer(model=model,
                              tokenizer=tokenizer,
                              args=training_args,
                              **data_module)

    if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
        trainer.train(resume_from_checkpoint=True)
    else:
        trainer.train()
    trainer.save_state()
    safe_save_model_for_hf_trainer(trainer=trainer,
                                   output_dir=training_args.output_dir)


if __name__ == "__main__":
    train()
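
# Illustrative launch command (paths, script location, and GPU count are assumptions, not taken from
# this repo; note that output_dir is overridden inside train(), see above):
#   torchrun --nproc_per_node=4 pointllm/train/train.py \
#       --model_name_or_path /path/to/base_llm_weights \
#       --data_path /path/to/objaverse_data \
#       --anno_path /path/to/PointLLM_complex_instruction_70K.json \
#       --point_backbone_ckpt /path/to/point_model.pth \
#       --bf16 True --fix_llm True --fix_pointnet True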