# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Sample Generate GPT"""

import os
import sys

import deepspeed
import torch

# Put the parent of this script's directory on sys.path so the `megatron`
# package imports below resolve when the script is run directly.
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__),
                                             os.path.pardir)))
from megatron import get_args
from megatron import print_rank_0
from megatron import get_tokenizer
from megatron.core import mpu
from megatron.arguments import core_transformer_config_from_args
from megatron.checkpointing import load_checkpoint
from megatron.initialize import initialize_megatron
from megatron.model import GPTModel
from megatron.training import get_model
from megatron.text_generation_utils import generate_and_write_samples_unconditional
from megatron.text_generation_utils import generate_samples_input_from_file
from megatron.text_generation_utils import generate_samples_interactive


def model_provider(pre_process=True, post_process=True):
    """Build the GPT model used for generation.

    Args:
        pre_process: forwarded to ``GPTModel`` (pipeline first-stage flag).
        post_process: forwarded to ``GPTModel`` (pipeline last-stage flag).

    Returns:
        A ``GPTModel`` built from the transformer config derived from the
        global Megatron args, with ``parallel_output=False``.
    """
    args = get_args()
    config = core_transformer_config_from_args(args)

    print_rank_0('building GPT model ...')
    model = GPTModel(config=config, num_tokentypes=0, parallel_output=False,
                     pre_process=pre_process, post_process=post_process,
                     return_moe_loss=False)  # we need to set "return_moe_loss" for the inference_mode
    return model


def add_text_generate_args(parser):
    """Register the text-generation command-line arguments.

    Args:
        parser: an ``argparse.ArgumentParser``; a 'text generation' argument
            group is added to it.

    Returns:
        The same ``parser`` object.
    """
    group = parser.add_argument_group(title='text generation')

    group.add_argument("--temperature", type=float, default=1.0,
                       help='Sampling temperature.')
    group.add_argument("--greedy", action='store_true', default=False,
                       help='Use greedy sampling.')
    group.add_argument("--top_p", type=float, default=0.0,
                       help='Top p sampling.')
    group.add_argument("--top_k", type=int, default=0,
                       help='Top k sampling.')
    group.add_argument("--out-seq-length", type=int, default=1024,
                       help='Size of the output generated text.')
    group.add_argument("--sample-input-file", type=str, default=None,
                       help='Get input from file instead of interactive mode, '
                       'each line is an input.')
    group.add_argument("--sample-output-file", type=str, default=None,
                       help='Output file got from --sample-input-file')
    group.add_argument("--num-samples", type=int, default=0,
                       help='Number of samples to generate unconditionally, '
                       'defaults to 0 and interactive conditional sampling')
    group.add_argument("--genfile", type=str,
                       help='Output file when generating unconditionally')
    group.add_argument("--recompute", action='store_true',
                       help='During generation recompute all attention '
                       'instead of using previously computed keys/values.')
    group.add_argument("--local_rank", type=int, default=0,
                       help='local_rank')

    return parser


def print_latency(latency_set, title=""):
    """Print average and percentile statistics for a list of latencies.

    The first 10 measurements are treated as warm-up queries and dropped.
    Values are multiplied by 1000 and printed as milliseconds, so they are
    presumably given in seconds. Nothing is printed if no measurements
    remain after the warm-up is dropped.

    Args:
        latency_set: sequence of latency measurements. It is not modified;
            the statistics are computed on a sorted copy.
        title: label included in the header line.
    """
    # Skip the 10 warm-up queries. sorted() works on a copy, so the caller's
    # list is not reordered.
    samples = sorted(latency_set[10:])
    count = len(samples)
    if count == 0:
        return

    def _percentile(fraction):
        # Floor-index percentile over the sorted samples (same index formula
        # as the historical n50/n90/... computation).
        return samples[int((count - 1) * fraction + 1) - 1]

    avg = sum(samples) / count
    print("====== latency stats {0} ======".format(title))
    print("\tAvg Latency: {0:8.2f} ms".format(avg * 1000))
    for label, fraction in (("P50", 0.5), ("P90", 0.9), ("P95", 0.95),
                            ("P99", 0.99), ("P999", 0.999)):
        print("\t{0} Latency: {1:8.2f} ms".format(
            label, _percentile(fraction) * 1000))


def main():
    """Build the model, generate samples and report latencies.

    Steps:
      * Initialize Megatron with the text-generation arguments.
      * Build the model and load the checkpoint if ``--load`` is given.
      * Optionally wrap the model with DeepSpeed inference.
      * Generate samples: from a file or interactively when
        ``--num-samples`` is 0, otherwise unconditionally.
      * Print latency statistics on rank 0.
    """
    # Filled in by generate_and_write_samples_unconditional; they stay empty
    # in the interactive / file-input modes.
    latencies = []
    model_latencies = []
    single_token_latency = []

    initialize_megatron(extra_args_provider=add_text_generate_args,
                        args_defaults={'tokenizer_type': 'GPT2BPETokenizer',
                                       'no_load_rng': True,
                                       'no_load_optim': True})

    args = get_args()
    if args.num_layers_per_virtual_pipeline_stage is not None:
        print("Interleaved pipeline schedule is not yet supported for text generation.")
        # sys.exit rather than the site-module `exit`, which is meant for the
        # interactive interpreter and is missing under `python -S`.
        sys.exit()

    # Set up model and load checkpoint.
    model = get_model(model_provider)

    if args.load is not None:
        _ = load_checkpoint(model, None, None)

    # NOTE(review): get_model appears to return a list of model chunks; the
    # virtual-pipeline check above is presumably what guarantees a single
    # chunk here.
    assert len(model) == 1, "Above condition should have caught this"
    model = model[0]

    if args.ds_inference:
        model = ds_inference(model, args)
        print('> DeepSpeed Inference engine initialized')

    # Generate samples.
    if args.num_samples == 0:
        args.micro_batch_size = 1
        if args.sample_input_file is not None:
            generate_samples_input_from_file(model)
        else:
            generate_samples_interactive(model)
    else:
        generate_and_write_samples_unconditional(model, latencies,
                                                 single_token_latency,
                                                 model_latencies)

    # Only the global rank-0 process reports, to avoid duplicated output.
    if torch.distributed.get_rank() == 0:
        print_latency(latencies)
        print_latency(model_latencies, "model_latencies")
        print_latency(single_token_latency, "single_token_latency")


def ds_inference(model, args):
    """Wrap ``model`` in a DeepSpeed inference engine.

    The engine uses fp16 (``torch.half``), kernel injection, and Megatron's
    ``mpu`` for tensor parallelism. Its MoE settings come from
    ``args.num_experts`` and ``args.mlp_type``.

    Returns:
        The engine's wrapped module (``engine.module``), not the engine itself.
    """
    engine = deepspeed.init_inference(model=model,
                                      mp_size=args.tensor_model_parallel_size,
                                      tensor_parallel={"mpu": mpu},
                                      dtype=torch.half,
                                      replace_with_kernel_inject=True,
                                      moe_experts=args.num_experts,
                                      moe_type=args.mlp_type)

    return engine.module


if __name__ == "__main__":
    main()