# coding=utf-8
# Copyright (c) 2023 Habana Labs, Ltd. an Intel Company.
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Pretrain GPT"""

import torch
from functools import partial
from megatron import get_args
from megatron import print_rank_0
from megatron import get_timers
from megatron import get_tokenizer
from megatron import mpu
from megatron.data.gpt_dataset import build_train_valid_test_datasets
from megatron.model import GPTModel, GPTModelPipe
from megatron.training import pretrain
from megatron.utils import get_ltor_masks_and_position_ids
from megatron.utils import average_losses_across_data_parallel_group
from megatron.global_vars import get_current_device
from megatron.enums import PositionEmbeddingType

import deepspeed
from deepspeed.runtime.utils import see_memory_usage
import os
import subprocess

from torch import nn
import torch.nn.functional as F


def model_provider(pre_process=True, post_process=True, parallel_output=True):
    """Build the model."""

    print_rank_0('building GPT model ...')
    see_memory_usage(f"Before Building Model", force=True)

    args = get_args()
    with deepspeed.zero.Init(data_parallel_group=mpu.get_data_parallel_group(),
                             remote_device=None if args.remote_device == 'none' else args.remote_device,
                             config_dict_or_path=args.deepspeed_config,
                             enabled=args.zero_stage == 3,
                             mpu=mpu):
        current_device = get_current_device()
        if args.deepspeed and not args.no_pipeline_parallel:
            # Verify --deepspeed_activation_checkpointing: it is mandatory here. Otherwise the
            # model uses fork() mapping to Megatron's RNGStatesTrackerSingleton, while
            # GPTModelPipe uses DeepSpeed checkpoint activations, which use DeepSpeed's
            # RNGStatesTracker.
            if args.checkpoint_activations and args.checkpoint_activations_granularity == "full":
                assert args.deepspeed_activation_checkpointing, \
                    "Flag --deepspeed_activation_checkpointing is mandatory when using GPTModelPipe" \
                    " with checkpoint activations granularity full."

            model = GPTModelPipe(
                num_tokentypes=0,
                parallel_output=parallel_output,
            )
            # This is a hack to give us a reference to get_batch_pipe from within training.py.
            # We need to call model.set_batch_fn after deepspeed.initialize.
            model._megatron_batch_fn = get_batch_pipe

            # Precompute the attention mask and store it in args. This avoids having to
            # pipeline it as an activation during training. The mask is constant, and thus
            # we can reuse it.
            attention_mask = torch.tril(torch.ones(
                (1, args.seq_length, args.seq_length), device=current_device)).view(
                    1, 1, args.seq_length, args.seq_length)

            # Convert attention mask to binary:
            attention_mask = (attention_mask < 0.5)
            if args.fp16:
                attention_mask = attention_mask.half()
            elif args.bf16:
                attention_mask = attention_mask.bfloat16()

            if args.mask_tensor_adding:
                args.attn_mask = attention_mask * -10000.0
            else:
                args.attn_mask = attention_mask.to(torch.bool)
        else:
            assert args.position_embedding_type != PositionEmbeddingType.alibi, \
                "GPTModel doesn't yet support ALiBi positional encoding"
            model = GPTModel(
                num_tokentypes=0,
                parallel_output=parallel_output,
                pre_process=pre_process,
                post_process=post_process
            ).to(current_device)
    see_memory_usage(f"After Building Model", force=True)
    return model


def get_batch(data_iterator):
    """Generate a batch."""
    args = get_args()
    tokenizer = get_tokenizer()

    # Items and their type.
    keys = ['text']
    datatype = torch.int64

    # Broadcast data.
    if data_iterator is not None:
        data = next(data_iterator)
    else:
        data = None
    data_b = mpu.broadcast_data(keys, data, datatype)

    # Unpack.
    tokens_ = data_b['text'].long()
    if not args.use_seq_len_plus_one_tokens:
        labels = torch.roll(tokens_, shifts=-1, dims=1)
        labels[:, -1] = -1
        tokens = tokens_
    else:
        labels = tokens_[:, 1:].contiguous()
        tokens = tokens_[:, :-1].contiguous()

    # Get the masks and position ids.
    attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids(
        tokens,
        tokenizer.eod,
        args.reset_position_ids,
        args.reset_attention_mask,
        args.eod_mask_loss,
        labels=labels,
        dummy_sample=None)

    tokens[tokens == -1] = 0
    labels[labels == -1] = 0

    return tokens, labels, loss_mask, attention_mask, position_ids


def get_batch_pipe(data):
    """Modification of `get_batch` to work on `next(data_iterator)` instead of `data_iterator`."""
    args = get_args()
    tokenizer = get_tokenizer()

    # Items and their type.
    keys = ['text']
    datatype = torch.int64

    # Broadcast data.
    data_b = mpu.broadcast_data(keys, data, datatype)

    # Unpack.
    tokens_ = data_b['text'].long()
    if not args.use_seq_len_plus_one_tokens:
        labels = torch.roll(tokens_, shifts=-1, dims=1)
        labels[:, -1] = -1
        tokens = tokens_
    else:
        labels = tokens_[:, 1:].contiguous()
        tokens = tokens_[:, :-1].contiguous()

    # Get the masks and position ids.
    attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids(
        tokens,
        tokenizer.eod,
        args.reset_position_ids,
        args.reset_attention_mask,
        args.eod_mask_loss,
        labels=labels,
        dummy_sample=None)

    tokens[tokens == -1] = 0
    labels[labels == -1] = 0

    if args.curriculum_learning and args.curriculum_seqlen < tokens.size()[1]:
        # seqlen-based curriculum learning
        # tokens, position_ids, labels, loss_mask have size [batch size, seqlen]
        tokens = tokens[:, :args.curriculum_seqlen].contiguous()
        position_ids = position_ids[:, :args.curriculum_seqlen].contiguous()
        if labels is not None:
            labels = labels[:, :args.curriculum_seqlen].contiguous()
        loss_mask = loss_mask[:, :args.curriculum_seqlen].contiguous()

    return (tokens, position_ids, attention_mask), (labels, loss_mask)


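# Illustrative sketch of the label convention used by get_batch() and get_batch_pipe()
# above. The concrete values are hypothetical and not produced anywhere else in this file:
#
#     tokens_ = [[5, 6, 7, 8]]
#
#     # args.use_seq_len_plus_one_tokens == False (sample already has seq_length tokens):
#     tokens = [[5, 6, 7, 8]]
#     labels = [[6, 7, 8, -1]]        # tokens_ rolled left by one, -1 sentinel at the end
#
#     # args.use_seq_len_plus_one_tokens == True (sample has seq_length + 1 tokens):
#     tokens = [[5, 6, 7]]            # tokens_[:, :-1]
#     labels = [[6, 7, 8]]            # tokens_[:, 1:]
#
# The labels tensor (including any -1 sentinels) is passed to
# get_ltor_masks_and_position_ids(), presumably so the returned loss_mask can exclude
# those positions from the loss; the sentinels are then overwritten with 0 so every
# index remains valid for the vocabulary/embedding lookup.

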
def loss_func(loss_mask, moe_loss, mos_loss, output_tensor):
    args = get_args()
    losses = output_tensor.float()
    loss_mask = loss_mask.view(-1).float()
    loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()

    # Reduce loss for logging.
    averaged_loss = average_losses_across_data_parallel_group([loss])
    if args.mos or args.kd:
        # assert max(args.num_experts) >= 1
        loss = loss + moe_loss + mos_loss
        if args.mos:
            return loss, {'total loss': loss, 'lm loss': averaged_loss[0],
                          'moe loss': moe_loss, 'mos loss': mos_loss}
        elif args.kd:
            return loss, {'total loss': loss, 'lm loss': averaged_loss[0],
                          'moe loss': moe_loss, 'kd loss': mos_loss}
        print_rank_0('>>> total loss: {}, lm loss {}, kd loss {}'.format(
            loss, averaged_loss[0], mos_loss))
    else:
        if max(args.num_experts) <= 1:
            return loss, {'lm loss': averaged_loss[0]}
        else:
            loss = loss + moe_loss
            return loss, {'lm loss': averaged_loss[0], 'moe loss': moe_loss}


def calculate_mos_loss(args, stu_output, teacher_model, tokens, position_ids, attention_mask):
    mos_loss = 0
    alpha = args.kd_alpha_ce
    beta = args.kd_beta_ce
    kd_temp = args.kd_temp

    if teacher_model:
        with torch.no_grad():
            if args.curriculum_learning and args.curriculum_seqlen < args.seq_length:
                assert args.curriculum_seqlen is not None
                curriculum_seqlen = args.curriculum_seqlen
                tokens = tokens[:, :curriculum_seqlen].contiguous()
                position_ids = position_ids[:, :curriculum_seqlen].contiguous()
                attention_mask = attention_mask[:, :, :curriculum_seqlen, :curriculum_seqlen].contiguous()
                # No need to truncate labels as we do not need them for the teacher logits.
            tea_output, *tea_other_losses = teacher_model(tokens, position_ids, attention_mask)
            assert stu_output.size() == tea_output.size(), \
                'teacher and student output should match in size. Student: {}, Teacher: {}, CL seq length {}'.format(
                    stu_output.size(), tea_output.size(), args.curriculum_seqlen)

        student_logits = F.log_softmax(stu_output / kd_temp, dim=2)
        tea_logits = F.softmax(tea_output / kd_temp, dim=2)
        # The target is expected to hold probabilities. If we passed log-probabilities
        # instead, we would need to set log_target=True when initializing the KLDivLoss.

        mos_loss = kd_temp * kd_temp * nn.KLDivLoss(reduction='batchmean')(student_logits, tea_logits)

        mos_loss = mos_loss.div(args.seq_length) * beta
    return mos_loss


def forward_step(data_iterator, model, teacher_model=None):
    """Forward step."""
    args = get_args()
    timers = get_timers()

    # Get the batch.
    timers('batch-generator').start()
    tokens, labels, loss_mask, attention_mask, position_ids = get_batch(
        data_iterator)
    timers('batch-generator').stop()

    if args.mos or args.kd:
        # The forward function can return either the loss or the logits,
        # depending on whether labels are passed in.
        stu_output, *other_losses = model(tokens, position_ids, attention_mask)
        if args.curriculum_learning and args.curriculum_seqlen < args.seq_length:
            assert args.curriculum_seqlen is not None
            labels = labels[:, :args.curriculum_seqlen].contiguous()
        output_tensor = mpu.vocab_parallel_cross_entropy(stu_output.contiguous().float(), labels)
    else:
        output_tensor, *other_losses = model(tokens, position_ids, attention_mask,
                                             labels=labels)
    if args.curriculum_learning and args.curriculum_seqlen < args.seq_length:
        loss_mask = loss_mask[:, :args.curriculum_seqlen].contiguous()

    moe_losses = []
    for moe_loss in other_losses:
        if moe_loss is not None:
            moe_losses.append(moe_loss)
    moe_loss = sum(moe_losses) * args.moe_loss_coeff

    mos_loss = 0
    if args.mos or args.kd:
        assert model.training
        mos_loss = calculate_mos_loss(args, stu_output, teacher_model,
                                      tokens, position_ids, attention_mask)

    # output_tensor stores the standard loss; loss_func calculates the total loss.
    return output_tensor, partial(loss_func, loss_mask, moe_loss, mos_loss)


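# For reference, the language-model loss assembled in loss_func() above is a masked
# mean over token positions:
#
#     lm_loss = sum(losses * loss_mask) / sum(loss_mask)
#
# and the knowledge-distillation term built in calculate_mos_loss() is a
# temperature-scaled KL divergence between teacher and student distributions:
#
#     kd_loss = (beta / seq_length) * T^2 * KL( softmax(teacher / T) || softmax(student / T) )
#
# with T = args.kd_temp and beta = args.kd_beta_ce. KLDivLoss with reduction='batchmean'
# sums over the sequence and vocabulary dimensions and divides by the batch size, so the
# extra division by args.seq_length yields a per-token value. Note that alpha
# (args.kd_alpha_ce) is read but not used in calculate_mos_loss().

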
def train_valid_test_datasets_provider(train_val_test_num_samples):
    """Build train, valid, and test datasets."""
    args = get_args()

    print_rank_0('> building train, validation, and test datasets '
                 'for GPT ...')
    train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
        data_prefix=args.data_path,
        train_data_prefix=args.train_data_path,
        valid_data_prefix=args.valid_data_path,
        test_data_prefix=args.test_data_path,
        data_impl=args.data_impl,
        splits_string=args.split,
        train_valid_test_num_samples=train_val_test_num_samples,
        seq_length=args.seq_length,
        seed=args.seed,
        skip_warmup=(not args.mmap_warmup),
        use_seq_len_plus_one_tokens=args.use_seq_len_plus_one_tokens)
    print_rank_0("> finished creating GPT datasets ...")

    return train_ds, valid_ds, test_ds


def command_exists(cmd):
    result = subprocess.Popen(f'type {cmd}', stdout=subprocess.PIPE, shell=True)
    return result.wait() == 0


def git_ds_info():
    from deepspeed.env_report import main as ds_report
    ds_report()

    # Write out version/git info.
    git_hash_cmd = "git rev-parse --short HEAD"
    git_branch_cmd = "git rev-parse --abbrev-ref HEAD"
    if command_exists('git'):
        try:
            result = subprocess.check_output(git_hash_cmd, shell=True)
            git_hash = result.decode('utf-8').strip()
            result = subprocess.check_output(git_branch_cmd, shell=True)
            git_branch = result.decode('utf-8').strip()
        except subprocess.CalledProcessError:
            git_hash = "unknown"
            git_branch = "unknown"
    else:
        git_hash = "unknown"
        git_branch = "unknown"
    print(f'**** Git info for Megatron: git_hash={git_hash} git_branch={git_branch} ****')


if __name__ == "__main__":
    git_ds_info()
    pretrain(train_valid_test_datasets_provider,
             model_provider,
             forward_step,
             args_defaults={'tokenizer_type': 'GPT2BPETokenizer'})
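

# A minimal, hypothetical launch sketch (the exact flag set depends on your
# Megatron-DeepSpeed configuration and accelerator; this is not a tested command line):
#
#     deepspeed pretrain_gpt.py \
#         --num-layers 24 --hidden-size 1024 --num-attention-heads 16 \
#         --seq-length 2048 --max-position-embeddings 2048 \
#         --micro-batch-size 4 --global-batch-size 256 \
#         --data-path <path/to/dataset_prefix> \
#         --vocab-file <gpt2-vocab.json> --merge-file <gpt2-merges.txt> \
#         --deepspeed --deepspeed_config <ds_config.json>
#
# pretrain() parses these arguments via megatron.get_args() and drives the training
# loop using the three callbacks passed above.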