Spaces:
Configuration error
Configuration error
#!/usr/bin/env python | |
# -*- coding: utf-8 -*- | |
# | |
# Copyright (c) 2023 Intel Corporation | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
# | |
# SPDX-License-Identifier: Apache-2.0 | |
# | |
import os | |
import argparse | |
import tempfile | |
from transformers import AutoTokenizer | |
from filelock import FileLock | |
from downloader.datasets import DataDownloader | |
from downloader.models import ModelDownloader | |
from tlt.distributed.pytorch.utils.pyt_distributed_utils import ( | |
DistributedTorch, | |
DistributedTrainingArguments, | |
HorovodTrainer | |
) | |
if __name__ == "__main__": | |
default_data_dir = os.path.join(tempfile.gettempdir(), 'data') | |
default_output_dir = os.path.join(tempfile.gettempdir(), 'output') | |
for d in [default_data_dir, default_output_dir]: | |
if not os.path.exists(d): | |
os.makedirs(d) | |
def directory_path(path): | |
if os.path.isdir(path): | |
return path | |
else: | |
raise argparse.ArgumentTypeError("'{}' is not a valid directory path.".format(path)) | |
print("******Distributed Training*****") | |
description = 'Distributed training with PyTorch.' | |
parser = argparse.ArgumentParser(description=description) | |
parser.add_argument('--master_addr', type=str, required=False, help="Master node to run this script") | |
parser.add_argument('--master_port', type=str, required=False, default='29500', help='Master port') | |
parser.add_argument('--backend', type=str, required=False, default='ccl', help='Type of backend to use ' | |
'(default: ccl)') | |
parser.add_argument('--use-case', '--use_case', type=str, required=True, | |
help='Use case (image_classification|text_classification)') | |
parser.add_argument('--epochs', type=int, required=False, default=1, help='Total epochs to train the model') | |
parser.add_argument('--batch_size', type=int, required=False, default=128, | |
help='Global batch size to distribute data (default: 128)') | |
parser.add_argument('--disable_ipex', action='store_true', required=False, help="Disables IPEX optimization to " | |
"the model. No effect when given --use-horovod as horovod with IPEX isn't supported.") | |
parser.add_argument('--tlt_saved_objects_dir', type=directory_path, required=False, help='Path to TLT saved ' | |
'distributed objects. The path must be accessible to all the nodes. For example: mounted ' | |
'NFS drive. This arg is helpful when using TLT API/CLI. ' | |
'See DistributedTorch.load_saved_objects() for more information.') | |
parser.add_argument('--use-horovod', '--use_horovod', action='store_true', help='Use horovod for distributed ' | |
'training.') | |
parser.add_argument('--cuda', action='store_true', help='Use cuda device for distributed training') | |
parser.add_argument('--dataset-dir', '--dataset_dir', type=directory_path, default=default_data_dir, | |
help="Path to dataset directory to save/load tfds dataset. This arg is helpful if you " | |
"plan to use this as a stand-alone script. Custom dataset is not supported yet!") | |
parser.add_argument('--output-dir', '--output_dir', type=directory_path, default=default_output_dir, | |
help="Path to save the trained model and store logs. This arg is helpful if you " | |
"plan to use this as a stand-alone script") | |
parser.add_argument('--dataset-name', '--dataset_name', type=str, default=None, | |
help="Dataset name to load from torchvision/Huggingface. This arg is helpful if you " | |
"plan to use this as a stand-alone script. Custom dataset is not supported yet!") | |
parser.add_argument('--model-name', '--model_name', type=str, default=None, | |
help="Torchvision image classification model name " | |
"(or) Huggingface hub name for text classification models. This arg is helpful if you " | |
"plan to use this as a stand-alone script.") | |
parser.add_argument('--max_seq_length', type=int, default=128, | |
help='Maximum sequence length that the model will be used with for text classification') | |
args = parser.parse_args() | |
train_data = None | |
model = None | |
optimizer, loss = None, None | |
data_kwargs = {} | |
if args.tlt_saved_objects_dir is not None: | |
# Load the saved dataset and model objects | |
loaded_objects = DistributedTorch.load_saved_objects(args.tlt_saved_objects_dir) | |
train_data = loaded_objects.get('train_data') | |
model = loaded_objects['model'] | |
loss = loaded_objects['loss'] | |
optimizer = loaded_objects['optimizer'] | |
data_kwargs['is_preprocessed'] = True | |
else: | |
if args.dataset_name is None: | |
raise argparse.ArgumentError(args.dataset_name, "Please provide a dataset name to load from torchvision " | |
"(or) datasets using --dataset-name") | |
if args.model_name is None: | |
raise argparse.ArgumentError(args.model_name, "Please provide torchvision model name (or) " | |
"Huggingface hub name using --model-name") | |
catalog = 'torchvision' if args.use_case == 'image_classification' else 'hugging_face' | |
with FileLock(os.path.expanduser('~/.horovod_lock')): | |
train_data = DataDownloader(args.dataset_name, args.dataset_dir, catalog).download(split='train') | |
model = ModelDownloader(args.model_name, catalog, args.output_dir).download() | |
if args.use_case == 'text_classification': | |
data_kwargs['hf_tokenizer'] = AutoTokenizer.from_pretrained(args.model_name) | |
data_kwargs['max_seq_length'] = args.max_seq_length | |
data_kwargs['text_column_names'] = [c for c in train_data.column_names if c != 'label'] | |
data_kwargs['is_preprocessed'] = False | |
if args.use_horovod: | |
hvd_trainer = HorovodTrainer(args.cuda) | |
train_loader, train_sampler = hvd_trainer.prepare_data(train_data, args.use_case, args.batch_size, | |
**data_kwargs) | |
hvd_trainer.prepare_model(model, args.use_case, optimizer, loss) | |
hvd_trainer.fit(train_loader, train_sampler, args.use_case, args.epochs) | |
else: | |
# Launch distributed job | |
training_args = DistributedTrainingArguments( | |
dataset=train_data, | |
model=model, | |
criterion=loss, | |
optimizer=optimizer, | |
epochs=args.epochs, | |
batch_size=args.batch_size, | |
disable_ipex=args.disable_ipex | |
) | |
dt = DistributedTorch(use_case=args.use_case) | |
dt.launch_distributed_job(training_args, args.master_addr, args.master_port, args.backend) | |