multimodalart (HF Staff) committed on
Commit 7702d69 · verified · 1 parent: 2d97c45

Create train.py

Files changed (1)
  1. blip3o/train/train.py +1025 -0
blip3o/train/train.py ADDED
@@ -0,0 +1,1025 @@
import os
import io
import copy
from dataclasses import dataclass, field
import json
import logging
import pathlib
from typing import Dict, Optional, Sequence, List
import time
import torch, gc
import glob
import transformers
import tokenizers
import random
from blip3o.constants import IGNORE_INDEX, DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_IDX
from torch.utils.data import Dataset
from blip3o.train.blip3o_trainer import blip3oTrainer
from blip3o import conversation as conversation_lib
from blip3o.model import *
from blip3o.mm_utils import tokenizer_image_token
from PIL import Image, ImageFile
from datasets import load_dataset, concatenate_datasets
from pathlib import Path
from datasets.utils.logging import set_verbosity_info
from transformers import logging as tf_logging
import torchvision.transforms as T
from torchvision.transforms.functional import InterpolationMode
from transformers import AutoProcessor

ImageFile.LOAD_TRUNCATED_IMAGES = True
transform_und_images = T.Compose([T.Resize(448, interpolation=InterpolationMode.BICUBIC, antialias=True), T.CenterCrop(448)])

set_verbosity_info()
tf_logging.set_verbosity_info()

local_rank = None


def rank0_print(*args):
    if local_rank == 0:
        print(*args)


from packaging import version

IS_TOKENIZER_GREATER_THAN_0_14 = version.parse(tokenizers.__version__) >= version.parse("0.14")


@dataclass
class ModelArguments:
    model_name_or_path: Optional[str] = field(default="facebook/opt-125m")
    version: Optional[str] = field(default="v0")
    freeze_backbone: bool = field(default=True)
    tune_mm_mlp_adapter: bool = field(default=False)
    vision_tower: Optional[str] = field(default=None)
    gen_vision_tower: Optional[str] = field(default=None)
    mm_vision_select_layer: Optional[int] = field(default=-1)  # default to the last layer
    pretrain_mm_mlp_adapter: Optional[str] = field(default=None)
    pretrain_gen_mlp_adapter: Optional[str] = field(default=None)
    vision_tower_pretrained: Optional[str] = field(default=None)
    mm_projector_type: Optional[str] = field(default="linear")
    gen_projector_type: Optional[str] = field(default="linear")
    mm_use_im_start_end: bool = field(default=False)
    mm_use_im_patch_token: bool = field(default=True)
    mm_patch_merge_type: Optional[str] = field(default="flat")
    mm_vision_select_feature: Optional[str] = field(default="patch")
    n_query: Optional[int] = field(default=729)  # clip 576, siglip 729
    n_und_query: Optional[int] = field(default=729)  # clip 576, siglip 729
    gen_pooling: Optional[str] = field(default="all")  # options are: pool2d_3, pool2d_9, seq_3, seq_9, seq_27


@dataclass
class DataArguments:
    data_path: str = field(default=None, metadata={"help": "Path to the training data."})
    lazy_preprocess: bool = False
    is_multimodal: bool = False
    image_folder: Optional[str] = field(default=None)
    shortcaption_image_folder: Optional[str] = field(default=None)
    data_type: Optional[str] = field(default="mix")
    image_aspect_ratio: str = "square"


@dataclass
class TrainingArguments(transformers.TrainingArguments):
    cache_dir: Optional[str] = field(default=None)
    optim: str = field(default="adamw_torch")
    remove_unused_columns: bool = field(default=False)
    freeze_mm_mlp_adapter: bool = field(default=False)
    mpt_attn_impl: Optional[str] = field(default="triton")
    model_max_length: int = field(
        default=512,
        metadata={"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."},
    )
    double_quant: bool = field(
        default=True,
        metadata={"help": "Compress the quantization statistics through double quantization."},
    )
    quant_type: str = field(
        default="nf4",
        metadata={"help": "Quantization data type to use. Should be one of `fp4` or `nf4`."},
    )
    bits: int = field(default=16, metadata={"help": "How many bits to use."})
    lora_enable: bool = False
    lora_r: int = 64
    lora_alpha: int = 16
    lora_dropout: float = 0.05
    lora_weight_path: str = ""
    lora_bias: str = "none"
    mm_projector_lr: Optional[float] = None
    group_by_modality_length: bool = field(default=False)


def maybe_zero_3(param, ignore_status=False, name=None):
    from deepspeed import zero
    from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus

    if hasattr(param, "ds_id"):
        if param.ds_status == ZeroParamStatus.NOT_AVAILABLE:
            if not ignore_status:
                logging.warning(f"{name}: param.ds_status != ZeroParamStatus.NOT_AVAILABLE: {param.ds_status}")
        with zero.GatheredParameters([param]):
            param = param.data.detach().cpu().clone()
    else:
        param = param.detach().cpu().clone()
    return param


# Borrowed from peft.utils.get_peft_model_state_dict
def get_peft_state_maybe_zero_3(named_params, bias):
    if bias == "none":
        to_return = {k: t for k, t in named_params if "lora_" in k}
    elif bias == "all":
        to_return = {k: t for k, t in named_params if "lora_" in k or "bias" in k}
    elif bias == "lora_only":
        to_return = {}
        maybe_lora_bias = {}
        lora_bias_names = set()
        for k, t in named_params:
            if "lora_" in k:
                to_return[k] = t
                bias_name = k.split("lora_")[0] + "bias"
                lora_bias_names.add(bias_name)
            elif "bias" in k:
                maybe_lora_bias[k] = t
        # Keep only the bias tensors that belong to a LoRA-adapted module.
        for k, t in maybe_lora_bias.items():
            if k in lora_bias_names:
                to_return[k] = t
    else:
        raise NotImplementedError
    to_return = {k: maybe_zero_3(v, ignore_status=True) for k, v in to_return.items()}
    return to_return


def get_peft_state_non_lora_maybe_zero_3(named_params, require_grad_only=True):
    to_return = {k: t for k, t in named_params if "lora_" not in k}
    if require_grad_only:
        to_return = {k: t for k, t in to_return.items() if t.requires_grad}
    to_return = {k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()}
    return to_return


def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match):
    to_return = {k: t for k, t in named_params if any(key_match in k for key_match in keys_to_match)}
    to_return = {k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()}
    return to_return


def get_vision_tower_state_maybe_zero_3(named_params, keys_to_match=[""]):
    to_return = {k: t for k, t in named_params if any(key_match in k for key_match in keys_to_match)}
    to_return = {k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()}
    return to_return


def find_all_linear_names(model):
    cls = torch.nn.Linear
    lora_module_names = set()
    multimodal_keywords = ["mm_projector", "vision_tower", "vision_resampler"]
    for name, module in model.named_modules():
        if any(mm_keyword in name for mm_keyword in multimodal_keywords):
            continue
        if isinstance(module, cls):
            names = name.split(".")
            lora_module_names.add(names[0] if len(names) == 1 else names[-1])

    if "lm_head" in lora_module_names:  # needed for 16-bit
        lora_module_names.remove("lm_head")
    return list(lora_module_names)


def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str, vision_tower: str):
    """Collect the state dict and dump it to disk."""

    if trainer.deepspeed:
        torch.cuda.synchronize()

    # Only save the multimodal projector (adapter) weights.
    keys_to_match = ["mm_projector"]
    if getattr(trainer.args, "use_im_start_end", False):
        keys_to_match.extend(["embed_tokens", "embed_in"])

    weight_to_save = get_mm_adapter_state_maybe_zero_3(trainer.model.named_parameters(), keys_to_match)
    trainer.model.config.save_pretrained(output_dir)

    current_folder = output_dir.split("/")[-1]
    parent_folder = os.path.dirname(output_dir)
    if trainer.args.local_rank == 0 or trainer.args.local_rank == -1:
        if current_folder.startswith("checkpoint-"):
            mm_projector_folder = os.path.join(parent_folder, "mm_projector")
            os.makedirs(mm_projector_folder, exist_ok=True)
            torch.save(
                weight_to_save,
                os.path.join(mm_projector_folder, f"{current_folder}.bin"),
            )
        else:
            torch.save(weight_to_save, os.path.join(output_dir, "mm_projector.bin"))

    # Save the generation projector weights in the same way.
    keys_to_match = ["gen_projector"]
    if getattr(trainer.args, "use_im_start_end", False):
        keys_to_match.extend(["embed_tokens", "embed_in"])

    weight_to_save = get_mm_adapter_state_maybe_zero_3(trainer.model.named_parameters(), keys_to_match)
    trainer.model.config.save_pretrained(output_dir)

    current_folder = output_dir.split("/")[-1]
    parent_folder = os.path.dirname(output_dir)
    if trainer.args.local_rank == 0 or trainer.args.local_rank == -1:
        if current_folder.startswith("checkpoint-"):
            gen_projector_folder = os.path.join(parent_folder, "gen_projector")
            os.makedirs(gen_projector_folder, exist_ok=True)
            torch.save(
                weight_to_save,
                os.path.join(gen_projector_folder, f"{current_folder}.bin"),
            )
        else:
            torch.save(weight_to_save, os.path.join(output_dir, "gen_projector.bin"))

    if trainer.deepspeed:
        torch.cuda.synchronize()
        trainer.save_model(output_dir)
        return

    state_dict = trainer.model.state_dict()
    if trainer.args.should_save:
        cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()}
        del state_dict
        trainer._save(output_dir, state_dict=cpu_state_dict)  # noqa


def smart_tokenizer_and_embedding_resize(
    special_tokens_dict: Dict,
    tokenizer: transformers.PreTrainedTokenizer,
    model: transformers.PreTrainedModel,
):
    """Add new special tokens and resize the embedding matrix accordingly.

    Newly added input-embedding rows are initialized with the mean of the existing embeddings.
    """
    num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
    model.resize_token_embeddings(len(tokenizer))

    if num_new_tokens > 0:
        input_embeddings = model.get_input_embeddings().weight.data
        input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
        input_embeddings[-num_new_tokens:] = input_embeddings_avg


def _tokenize_fn(strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer) -> Dict:
    """Tokenize a list of strings."""
    tokenized_list = [
        tokenizer(
            text,
            return_tensors="pt",
            padding="longest",
            max_length=tokenizer.model_max_length,
            truncation=True,
        )
        for text in strings
    ]
    input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list]
    input_ids_lens = labels_lens = [tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item() for tokenized in tokenized_list]
    return dict(
        input_ids=input_ids,
        labels=labels,
        input_ids_lens=input_ids_lens,
        labels_lens=labels_lens,
    )


def _mask_targets(target, tokenized_lens, speakers):
    # cur_idx = 0
    cur_idx = tokenized_lens[0]
    tokenized_lens = tokenized_lens[1:]
    target[:cur_idx] = IGNORE_INDEX
    for tokenized_len, speaker in zip(tokenized_lens, speakers):
        if speaker == "human":
            target[cur_idx + 2 : cur_idx + tokenized_len] = IGNORE_INDEX
        cur_idx += tokenized_len


def _add_speaker_and_signal(header, source, get_conversation=True):
    """Add speaker and start/end signal on each round."""
    BEGIN_SIGNAL = "### "
    END_SIGNAL = "\n"
    conversation = header
    for sentence in source:
        from_str = sentence["from"]
        if from_str.lower() == "human":
            from_str = conversation_lib.default_conversation.roles[0]
        elif from_str.lower() == "gpt":
            from_str = conversation_lib.default_conversation.roles[1]
        else:
            from_str = "unknown"
        sentence["value"] = BEGIN_SIGNAL + from_str + ": " + sentence["value"] + END_SIGNAL
        if get_conversation:
            conversation += sentence["value"]
    conversation += BEGIN_SIGNAL
    return conversation


def preprocess_multimodal(sources: Sequence[str], data_args: DataArguments) -> Dict:
    is_multimodal = data_args.is_multimodal
    if not is_multimodal:
        return sources
    und_placeholder = "<|vision_start|>" + "<|image_pad|>" * data_args.n_und_query + "<|vision_end|>"
    gen_placeholder = ""
    # "[IMG]" + "<image>" * data_args.n_query + "[/IMG]"
    inst_type = None
    for source in sources:  # [instance]
        for sentence in source:
            if sentence["from"] == "human" and "<image>" in sentence["value"]:
                sentence["value"] = sentence["value"].replace(DEFAULT_IMAGE_TOKEN, und_placeholder).strip()
                inst_type = "und"
            elif sentence["from"] == "gpt" and "<image>" in sentence["value"]:
                sentence["value"] = sentence["value"].replace(DEFAULT_IMAGE_TOKEN, gen_placeholder).strip()
                inst_type = "gen"
    return sources, inst_type


def preprocess_qwen(sources, tokenizer: transformers.PreTrainedTokenizer, has_image: bool = False, max_len=2048, system_message: str = "You are a helpful assistant.") -> Dict:
    roles = {"human": "user", "gpt": "assistant"}

    # Use a deepcopy so that the caller's tokenizer is not modified.
    tokenizer = copy.deepcopy(tokenizer)
    chat_template = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"
    tokenizer.chat_template = chat_template

    # Apply prompt templates
    input_ids, targets = [], []
    for i, source in enumerate(sources):
        if roles[source[0]["from"]] != roles["human"]:
            source = source[1:]

        input_id, target = [], []

        # Build the system message for each conversation.
        input_id += tokenizer.apply_chat_template([{"role": "system", "content": system_message}])
        target += [IGNORE_INDEX] * len(input_id)

        for conv in source:
            try:
                role = conv["role"]
                content = conv["content"]
            except KeyError:
                role = conv["from"]
                content = conv["value"]

            role = roles.get(role, role)

            conv = [{"role": role, "content": content}]
            encode_id = tokenizer.apply_chat_template(conv)
            input_id += encode_id
            if role in ["user", "system"]:
                # Prompt tokens are masked out of the loss.
                target += [IGNORE_INDEX] * len(encode_id)
            else:
                target += encode_id

        assert len(input_id) == len(target), f"{len(input_id)} != {len(target)}"

        input_ids.append(input_id)
        targets.append(target)
    input_ids = torch.tensor(input_ids, dtype=torch.long)
    targets = torch.tensor(targets, dtype=torch.long)

    return dict(
        input_ids=input_ids,  # tensor(bs x seq_len)
        labels=targets,  # tensor(bs x seq_len)
    )
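
# Illustration only (derived from the chat template above, not part of the training logic):
# a source [{"from": "human", "value": "Describe the image."},
#           {"from": "gpt", "value": "A cat on a sofa."}]
# is rendered, before tokenization, as
#   <|im_start|>system\nYou are a helpful assistant.<|im_end|>\n
#   <|im_start|>user\nDescribe the image.<|im_end|>\n
#   <|im_start|>assistant\nA cat on a sofa.<|im_end|>\n
# and only the assistant turn contributes label ids; system and user turns become IGNORE_INDEX.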


def preprocess_llama3(
    sources,
    tokenizer: transformers.PreTrainedTokenizer,
    has_image: bool = False,
    max_len=2048,
    system_message: str = "You are a helpful language and vision assistant. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.",
) -> Dict:
    # roles = {"human": "<|start_header_id|>user<|end_header_id|>", "gpt": "<|start_header_id|>assistant<|end_header_id|>"}
    roles = {"human": "user", "gpt": "assistant"}

    # Add the image token to the tokenizer as a special token.
    # Use a deepcopy so that the caller's tokenizer is not modified.
    tokenizer = copy.deepcopy(tokenizer)
    # When there actually is an image, we add the image token as a special token.
    if has_image:
        tokenizer.add_tokens(["<image>"], special_tokens=True)
    image_token_index = tokenizer.convert_tokens_to_ids("<image>")
    bos_token_id = tokenizer.convert_tokens_to_ids("<|begin_of_text|>")
    start_header_id = tokenizer.convert_tokens_to_ids("<|start_header_id|>")
    end_header_id = tokenizer.convert_tokens_to_ids("<|end_header_id|>")
    eot_id = tokenizer.convert_tokens_to_ids("<|eot_id|>")

    unmask_tokens = ["<|begin_of_text|>", "<|start_header_id|>", "<|end_header_id|>", "<|eot_id|>", "\n\n"]
    unmask_tokens_idx = [tokenizer.convert_tokens_to_ids(tok) for tok in unmask_tokens]

    # Recent llama3 tokenizers automatically prepend the bos id when called,
    # so strip it off when it appears.
    def safe_tokenizer_llama3(text):
        input_ids = tokenizer(text).input_ids
        if input_ids[0] == bos_token_id:
            input_ids = input_ids[1:]
        return input_ids

    nl_tokens = tokenizer.convert_tokens_to_ids("\n\n")
    # Apply prompt templates
    input_ids, targets = [], []
    for i, source in enumerate(sources):
        if roles[source[0]["from"]] != roles["human"]:
            source = source[1:]

        input_id, target = [], []

        # Build the system message for each conversation.
        input_id += tokenizer.apply_chat_template([{"role": "system", "content": system_message}])
        target += [IGNORE_INDEX] * len(input_id)

        for conv in source:
            try:
                role = conv["role"]
                content = conv["content"]
            except KeyError:
                role = conv["from"]
                content = conv["value"]

            role = roles.get(role, role)

            conv = [{"role": role, "content": content}]
            # The first token is the bos token, which we do not need here.
            encode_id = tokenizer.apply_chat_template(conv)[1:]
            input_id += encode_id
            if role in ["user", "system"]:
                target += [IGNORE_INDEX] * len(encode_id)
            else:
                target += encode_id

        assert len(input_id) == len(target), f"{len(input_id)} != {len(target)}"
        for idx, encode_id in enumerate(input_id):
            if encode_id in unmask_tokens_idx:
                target[idx] = encode_id
            if encode_id == image_token_index:
                # Replace the tokenizer's <image> id with the placeholder id from blip3o.constants.
                input_id[idx] = IMAGE_TOKEN_IDX
        input_ids.append(input_id)
        targets.append(target)
    input_ids = torch.tensor(input_ids, dtype=torch.long)
    targets = torch.tensor(targets, dtype=torch.long)

    return dict(
        input_ids=input_ids,  # tensor(bs x seq_len)
        labels=targets,  # tensor(bs x seq_len)
    )


def preprocess_plain(
    sources: Sequence[str],
    tokenizer: transformers.PreTrainedTokenizer,
) -> Dict:
    # add end signal and concatenate together
    conversations = []
    for source in sources:
        assert len(source) == 2
        # assert DEFAULT_IMAGE_TOKEN in source[0]['value'] or DEFAULT_IMAGE_TOKEN in source[1]['value']
        conversation = source[0]["value"] + source[1]["value"] + conversation_lib.default_conversation.sep
        conversations.append(conversation)
    # tokenize conversations
    input_ids = [tokenizer_image_token(prompt, tokenizer, return_tensors="pt") for prompt in conversations]
    targets = copy.deepcopy(input_ids)
    for target, source in zip(targets, sources):
        tokenized_len = len(tokenizer_image_token(source[0]["value"], tokenizer))
        target[:tokenized_len] = IGNORE_INDEX

    return dict(input_ids=input_ids, labels=targets)


def preprocess(
    sources: Sequence[str],
    tokenizer: transformers.PreTrainedTokenizer,
    has_image: bool = False,
) -> Dict:
    """
    Given a list of sources, each of which is a conversation list, this transform:
    1. Adds the signal '### ' at the beginning of each sentence, with the end signal '\n';
    2. Concatenates the conversations together;
    3. Tokenizes the concatenated conversation;
    4. Makes a deepcopy as the target and masks human words with IGNORE_INDEX.
    """
    if conversation_lib.default_conversation.sep_style == conversation_lib.SeparatorStyle.PLAIN:
        return preprocess_plain(sources, tokenizer)
    if conversation_lib.default_conversation.version == "llama3":
        return preprocess_llama3(sources, tokenizer, has_image=has_image)
    if conversation_lib.default_conversation.version == "qwen":
        return preprocess_qwen(sources, tokenizer, has_image=has_image)
    # add end signal and concatenate together
    conversations = []
    for source in sources:
        header = f"{conversation_lib.default_conversation.system}\n\n"
        conversation = _add_speaker_and_signal(header, source)
        conversations.append(conversation)

    # tokenize conversations
    def get_tokenize_len(prompts):
        return [len(tokenizer_image_token(prompt, tokenizer)) for prompt in prompts]

    if has_image:
        input_ids = [tokenizer_image_token(prompt, tokenizer, return_tensors="pt") for prompt in conversations]
    else:
        conversations_tokenized = _tokenize_fn(conversations, tokenizer)
        input_ids = conversations_tokenized["input_ids"]

    targets = copy.deepcopy(input_ids)
    for target, source in zip(targets, sources):
        if has_image:
            tokenized_lens = get_tokenize_len([header] + [s["value"] for s in source])
        else:
            tokenized_lens = _tokenize_fn([header] + [s["value"] for s in source], tokenizer)["input_ids_lens"]
        speakers = [sentence["from"] for sentence in source]
        _mask_targets(target, tokenized_lens, speakers)

    return dict(input_ids=input_ids, labels=targets)


class LazySupervisedMixDataset(Dataset):
    """Dataset for supervised fine-tuning."""

    def __init__(
        self,
        data_path: str,
        tokenizer: transformers.PreTrainedTokenizer,
        data_args: DataArguments,
    ):
        super(LazySupervisedMixDataset, self).__init__()

        self.data_args = data_args
        list_data_dict = []

        ###################################### text to image #######################################
        data_files = glob.glob(os.path.join(self.data_args.image_folder, "*.tar"))
        ## text to image
        train_dataset = load_dataset("webdataset", data_files=data_files, split="train", num_proc=128)
        train_dataset = train_dataset.rename_column("jpg", "image")
        train_dataset = train_dataset.add_column('type', len(train_dataset) * ['T2I'])
        train_dataset = train_dataset.add_column('image_path', len(train_dataset) * [None])
        train_dataset = train_dataset.remove_columns([col for col in train_dataset.column_names if col not in (
            ["image", "txt", "type", "image_path"])])
        print(f"Finished loading image dataset: {len(train_dataset)} samples")
        list_data_dict.append(train_dataset)

        if len(list_data_dict) > 1:
            list_data_dict = concatenate_datasets(list_data_dict)
        else:
            list_data_dict = list_data_dict[0]
        list_data_dict = list_data_dict.shuffle(seed=42)

        rank0_print(f"Total number of training instances: {len(list_data_dict)}")
        self.tokenizer = tokenizer
        self.list_data_dict = list_data_dict

    def __len__(self):
        return len(self.list_data_dict)

    @property
    def lengths(self):
        length_list = []
        for sample in self.list_data_dict:
            img_tokens = 128 if "image" in sample else 0
            length_list.append(sum(len(conv["value"].split()) for conv in sample["conversations"]) + img_tokens)
        return length_list

    @property
    def modality_lengths(self):
        length_list = []
        for sample in self.list_data_dict:
            cur_len = sum(len(conv["value"].split()) for conv in sample["conversations"])
            cur_len = cur_len if "image" in sample else -cur_len
            length_list.append(cur_len)
        return length_list

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:

        while True:
            sources = self.list_data_dict[i]

            if sources["type"] == "T2I" or sources["type"] == "journeyDB_T2I":
                sources["conversations"] = [
                    {"from": "human", "value": f"Please generate image based on the following caption: {sources['txt']}"},
                    {"from": "gpt", "value": "<image>"},
                ]

            elif sources["type"] == "I2I" or sources["type"] == "journeyDB_I2I":
                sources["conversations"] = [
                    {
                        "from": "human",
                        "value": "<image>\nPlease reconstruct the given image.",
                    },
                    {"from": "gpt", "value": ""},
                ]

            else:
                raise ValueError("Unknown source type. Please check the 'type' in 'sources'.")

            if "image" in sources:

                def img_process(images, processor, image_aspect_ratio):
                    if image_aspect_ratio == "pad":

                        def expand2square(pil_img, background_color):
                            width, height = pil_img.size
                            if width == height:
                                return pil_img
                            elif width > height:
                                result = Image.new(pil_img.mode, (width, width), background_color)
                                result.paste(pil_img, (0, (width - height) // 2))
                                return result
                            else:
                                result = Image.new(pil_img.mode, (height, height), background_color)
                                result.paste(pil_img, ((height - width) // 2, 0))
                                return result

                        images = [expand2square(img, tuple(int(x * 255) for x in processor.image_mean)) for img in images]
                        images = processor.preprocess(images, return_tensors="pt")["pixel_values"]
                    else:
                        images = processor.preprocess(images, return_tensors="pt")["pixel_values"]
                    return images

                if sources["type"] == "T2I" or sources["type"] == "I2I":
                    image_files = self.list_data_dict[i]["image"]
                else:
                    image_files = self.list_data_dict[i]["image_path"]

                if not isinstance(image_files, list):
                    image_files = [image_files]

                images = []

                def read_bin_as_bytesio(bin_file_path):
                    with open(bin_file_path, "rb") as f:
                        return io.BytesIO(f.read())

                for img in image_files:
                    try:
                        if sources["type"] == "T2I" or sources["type"] == "I2I":
                            img = img.convert("RGB")
                        elif sources["type"] == "journeyDB_T2I" or sources["type"] == "journeyDB_I2I":
                            image_path = os.path.join('/fsx/sfr/data/jiuhai/hub/datasets--JourneyDB--JourneyDB/snapshots/e191aa61ca37e5e4418707ade4df5deb5c6d5d8f/data/train/imgs', img)
                            img = Image.open(image_path).convert("RGB")
                        else:
                            raise ValueError("Unknown source type. Please check the 'type' in 'sources'.")
                        images.append(img)
                    except Exception as e:
                        print(f"Error opening image {img}: {e}")
                        images = None
                        break  # Abort this item if there's an error

                if images is not None:
                    try:
                        # Run the generation image processor once to validate the images.
                        temp = img_process(
                            images,
                            self.data_args.gen_image_processor,
                            self.data_args.image_aspect_ratio,
                        )
                    except Exception as e:
                        print(f"Error: wrong number of channels: {e}")
                        images = None

                # If no valid images were found, randomly pick another item.
                if images is None:
                    print(sources)
                    print("Warning: invalid image, resampling another instance")
                    i = random.randint(0, len(self.list_data_dict) - 1)
                    continue

                sources, inst_type = preprocess_multimodal(copy.deepcopy([sources["conversations"]]), self.data_args)
            else:
                sources = copy.deepcopy([sources["conversations"]])
            data_dict = preprocess(sources, self.tokenizer, has_image=("image" in self.list_data_dict[i]))
            if isinstance(i, int):
                data_dict = dict(input_ids=data_dict["input_ids"][0], labels=data_dict["labels"][0])

            # image exists in the data
            if "image" in self.list_data_dict[i]:
                if inst_type == "gen":
                    data_dict["gen_image"] = img_process(
                        images,
                        self.data_args.gen_image_processor,
                        self.data_args.image_aspect_ratio,
                    )

                elif inst_type == "und":
                    resized_images = [transform_und_images(img) for img in images]

                    image_inputs = self.data_args.image_processor(resized_images, return_tensors="pt")

                    data_dict["und_image"] = image_inputs.pixel_values
                    data_dict["grid_thw"] = image_inputs.image_grid_thw
                    data_dict["gen_image"] = img_process(
                        resized_images,
                        self.data_args.gen_image_processor,
                        self.data_args.image_aspect_ratio,
                    )

            elif self.data_args.is_multimodal:
                crop_size = self.data_args.image_processor.crop_size
                data_dict["image"] = torch.zeros(3, crop_size["height"], crop_size["width"])

            data_dict["ids"] = self.list_data_dict[i]["id"] if "id" in self.list_data_dict[i] else "unk"
            return data_dict


@dataclass
class DataCollatorForSupervisedDataset(object):
    """Collate examples for supervised fine-tuning."""

    tokenizer: transformers.PreTrainedTokenizer

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        input_ids, labels, ids = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels", "ids"))
        multi_input_ids = []
        multi_labels = []
        i_s_pos = []
        for input_id, label in zip(input_ids, labels):
            # Reserve 65 positions at the end of each sequence for the image slot:
            # one fixed marker token followed by 64 image-token placeholders.
            input_id = input_id[: self.tokenizer.model_max_length - 65]
            label = label[: self.tokenizer.model_max_length - 65]
            i_s_pos.append(input_id.shape[0] + 1)
            img_id = torch.full((65,), IMAGE_TOKEN_IDX, dtype=input_id.dtype, device=input_id.device)
            img_id[0] = 151665
            input_id = torch.cat([input_id, img_id])
            img_label = torch.full((65,), IMAGE_TOKEN_IDX, dtype=label.dtype, device=label.device)
            img_label[0] = 151665
            label = torch.cat([label, img_label])
            multi_input_ids.append(input_id)
            multi_labels.append(label)

        input_ids = multi_input_ids
        labels = multi_labels

        input_ids = torch.nn.utils.rnn.pad_sequence(input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id)
        labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX)
        if input_ids.shape[1] > self.tokenizer.model_max_length:
            print(f"Warning: input with length {input_ids.shape[1]} is longer than max length {self.tokenizer.model_max_length}")
            input_ids = input_ids[:, : self.tokenizer.model_max_length]
            labels = labels[:, : self.tokenizer.model_max_length]
        batch = dict(
            input_ids=input_ids,
            labels=labels,
            attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
        )

        batch_gen_images = []
        batch_und_images = []
        batch_grid_thw = []

        for instance in instances:
            if "gen_image" in instance:
                batch_gen_images.append(instance["gen_image"])

        if len(batch_gen_images) > 0:
            if all(x is not None and y.shape == batch_gen_images[0][0].shape for x in batch_gen_images for y in x):
                batch["gen_image"] = torch.cat([images for images in batch_gen_images], dim=0)
            else:
                batch["gen_image"] = batch_gen_images
        else:
            batch["gen_image"] = None

        for instance in instances:
            if "und_image" in instance:
                batch_und_images.append(instance["und_image"].unsqueeze(0))  ## 1*1024*1176
                batch_grid_thw.append(instance["grid_thw"])  ## 1*3

        # print(f"batch_und_images {batch_und_images}")
        if len(batch_und_images) > 0:
            batch["und_image"] = torch.cat([images for images in batch_und_images], dim=0)
            batch["grid_thw"] = torch.cat([images for images in batch_grid_thw], dim=0)
        else:
            batch["und_image"] = None
            batch["grid_thw"] = None

        batch["ids"] = ids

        batch["i_s_pos"] = i_s_pos

        return batch
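
# Illustration only (for reading convenience; exact shapes depend on the processors configured in train()):
# each collated batch carries
#   input_ids / labels / attention_mask : padded [batch, seq] tensors, each sequence ending with
#                                         one marker token followed by 64 image placeholder tokens
#   gen_image   : pixel values for the generation vision tower (a single tensor when shapes match, else a list)
#   und_image, grid_thw : Qwen-VL image processor outputs, or None when no understanding images are present
#   ids, i_s_pos : sample ids and, per sample, the start index of the image placeholder span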


def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer, data_args) -> Dict:

    if data_args.data_type == "mix":
        train_dataset = LazySupervisedMixDataset(tokenizer=tokenizer, data_path=data_args.data_path, data_args=data_args)
    else:
        raise ValueError("Unknown data type. Please check the Dataloader type.")

    data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
    return dict(train_dataset=train_dataset, eval_dataset=None, data_collator=data_collator)


def unlock_vit(training_args, model_args, vision_tower):
    for n, p in vision_tower.named_parameters():
        p.requires_grad = True


def train(attn_implementation=None):
    global local_rank

    parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()
    print(model_args, data_args, training_args)
    local_rank = training_args.local_rank
    compute_dtype = torch.float16 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32)

    bnb_model_from_pretrained_args = {}
    if training_args.bits in [4, 8]:
        from transformers import BitsAndBytesConfig

        bnb_model_from_pretrained_args.update(
            dict(
                device_map={"": training_args.device},
                load_in_4bit=training_args.bits == 4,
                load_in_8bit=training_args.bits == 8,
                quantization_config=BitsAndBytesConfig(
                    load_in_4bit=training_args.bits == 4,
                    load_in_8bit=training_args.bits == 8,
                    llm_int8_skip_modules=["mm_projector"],
                    llm_int8_threshold=6.0,
                    llm_int8_has_fp16_weight=False,
                    bnb_4bit_compute_dtype=compute_dtype,
                    bnb_4bit_use_double_quant=training_args.double_quant,
                    bnb_4bit_quant_type=training_args.quant_type,  # {'fp4', 'nf4'}
                ),
            )
        )

    if model_args.vision_tower is not None:
        model = blip3oLlamaForCausalLM.from_pretrained(
            model_args.model_name_or_path,
            cache_dir=training_args.cache_dir,
            attn_implementation=attn_implementation,
            torch_dtype=(torch.bfloat16 if training_args.bf16 else None),
            **bnb_model_from_pretrained_args,
        )
    else:
        if "Qwen" in model_args.model_name_or_path or "qwen" in model_args.model_name_or_path:
            model = blip3oQwenForCausalLM.from_pretrained(
                model_args.model_name_or_path,
                cache_dir=training_args.cache_dir,
                attn_implementation=attn_implementation,
                torch_dtype=(torch.bfloat16 if training_args.bf16 else None),
                **bnb_model_from_pretrained_args,
            )
        else:
            model = transformers.LlamaForCausalLM.from_pretrained(
                model_args.model_name_or_path,
                cache_dir=training_args.cache_dir,
                attn_implementation=attn_implementation,
                torch_dtype=(torch.bfloat16 if training_args.bf16 else None),
                **bnb_model_from_pretrained_args,
            )
    model.config.use_cache = False

    if model_args.freeze_backbone:
        for (n, p) in model.get_model().named_parameters():
            p.requires_grad = False
        for (n, p) in model.visual.named_parameters():
            p.requires_grad = False
        for (n, p) in model.lm_head.named_parameters():
            p.requires_grad = False

    if training_args.gradient_checkpointing:
        if hasattr(model, "enable_input_require_grads"):
            model.enable_input_require_grads()
        else:

            def make_inputs_require_grad(module, input, output):
                output.requires_grad_(True)

            model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)

    if "Qwen" in model_args.model_name_or_path or "qwen" in model_args.model_name_or_path:
        tokenizer = AutoProcessor.from_pretrained(model_args.model_name_or_path).tokenizer
        tokenizer.model_max_length = training_args.model_max_length
    else:
        tokenizer = transformers.AutoTokenizer.from_pretrained(
            model_args.model_name_or_path,
            cache_dir=training_args.cache_dir,
            model_max_length=training_args.model_max_length,
            padding_side="right",
            use_fast=False,
        )

    # tokenizer.pad_token = tokenizer.unk_token
    if tokenizer.pad_token is None:
        smart_tokenizer_and_embedding_resize(
            special_tokens_dict=dict(
                pad_token="<pad>",
                additional_special_tokens=["[IMG]", "[/IMG]", "<image>"],
            ),
            tokenizer=tokenizer,
            model=model,
        )
    elif "<image>" not in tokenizer.get_added_vocab():
        smart_tokenizer_and_embedding_resize(
            special_tokens_dict=dict(additional_special_tokens=["[IMG]", "[/IMG]", "<image>"]),
            tokenizer=tokenizer,
            model=model,
        )

    if model_args.version in conversation_lib.conv_templates:
        conversation_lib.default_conversation = conversation_lib.conv_templates[model_args.version]
    else:
        conversation_lib.default_conversation = conversation_lib.conv_templates["llama3"]
    rank0_print(f"Using conversation format: {conversation_lib.default_conversation.version}")

    # if model_args.vision_tower is not None:
    model.get_model().initialize_vision_modules(model_args=model_args, fsdp=training_args.fsdp)

    ## generation vision tower
    gen_vision_tower = model.get_gen_vision_tower()
    gen_vision_tower.to(
        dtype=torch.bfloat16 if training_args.bf16 else torch.float16,
        device=training_args.device,
    )
    gen_vision_tower.requires_grad_(False)

    data_args.gen_image_processor = gen_vision_tower.image_processor
    data_args.image_processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct").image_processor

    data_args.is_multimodal = True
    data_args.n_query = model_args.n_query
    data_args.n_und_query = model_args.n_und_query

    model.config.image_aspect_ratio = data_args.image_aspect_ratio
    model.config.tokenizer_padding_side = tokenizer.padding_side
    model.config.tokenizer_model_max_length = tokenizer.model_max_length

    model.config.tune_mm_mlp_adapter = training_args.tune_mm_mlp_adapter = model_args.tune_mm_mlp_adapter

    model.config.freeze_mm_mlp_adapter = training_args.freeze_mm_mlp_adapter

    # Calculate total parameters and trainable parameters
    total_params = sum(p.numel() for p in model.get_model().parameters())
    trainable_params = sum(p.numel() for p in model.get_model().parameters() if p.requires_grad)

    print(f"Total parameters: {total_params}")
    print(f"Trainable parameters: {trainable_params}")

    model.config.mm_use_im_start_end = data_args.mm_use_im_start_end = model_args.mm_use_im_start_end
    model.config.mm_projector_lr = training_args.mm_projector_lr
    training_args.use_im_start_end = model_args.mm_use_im_start_end
    model.config.mm_use_im_patch_token = model_args.mm_use_im_patch_token
    model.initialize_vision_tokenizer(model_args, tokenizer=tokenizer)
    model.config.pad_token_id = tokenizer.pad_token_id

    data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args)

    trainer = blip3oTrainer(
        model=model,
        tokenizer=tokenizer,
        args=training_args,
        **data_module,
    )
    from tabulate import tabulate

    if trainer.is_world_process_zero():
        stat = []
        for i, (n, p) in enumerate(trainer.model.named_parameters()):
            stat.append([i, n, p.shape, p.requires_grad])
        print(tabulate(stat, headers=["idx", "name", "shape", "trainable"]))

    if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
        trainer.train(resume_from_checkpoint=True)
    else:
        trainer.train()
    trainer.save_state()

    model.config.use_cache = True
    safe_save_model_for_hf_trainer(
        trainer=trainer,
        output_dir=training_args.output_dir,
        vision_tower=model_args.vision_tower,
    )


if __name__ == "__main__":
    train()
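
For reference, a minimal sketch of how the argument dataclasses above can be driven from a launcher script; every path, model name, and hyperparameter below is a placeholder for illustration, not a value taken from this commit, and a real run would also configure the generation vision tower and a distributed launcher (e.g. DeepSpeed/torchrun):

import sys
from blip3o.train.train import train

if __name__ == "__main__":
    # Hypothetical flags; they map 1:1 onto ModelArguments / DataArguments / TrainingArguments fields.
    sys.argv = [
        "train.py",
        "--model_name_or_path", "Qwen/Qwen2.5-VL-7B-Instruct",
        "--version", "qwen",
        "--data_type", "mix",
        "--image_folder", "/path/to/webdataset/shards",   # directory containing the *.tar shards
        "--output_dir", "./checkpoints/blip3o-debug",
        "--bf16", "True",
        "--per_device_train_batch_size", "4",
        "--model_max_length", "512",
    ]
    train(attn_implementation="flash_attention_2")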