File size: 5,076 Bytes
6b02503 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 |
# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert slow tokenizers checkpoints in fast (serialization format of the `tokenizers` library)"""
import argparse
import os
import transformers
from .convert_slow_tokenizer import SLOW_TO_FAST_CONVERTERS
from .utils import logging
# Set the package logging verbosity to info (presumably so that the `logger.info` progress messages emitted by
# this script are shown -- confirm against the `.utils.logging` helpers).
logging.set_verbosity_info()
# Module-level logger used for every progress message in this script.
logger = logging.get_logger(__name__)
# Maps each slow tokenizer name registered in SLOW_TO_FAST_CONVERTERS to the fast tokenizer class exposed by
# `transformers`. The fast class is looked up as `<name>Fast`, except for Phi3 which uses the Llama tokenizer.
TOKENIZER_CLASSES = dict(
    (
        slow_name,
        getattr(transformers, slow_name + "Fast" if slow_name != "Phi3Tokenizer" else "LlamaTokenizerFast"),
    )
    for slow_name in SLOW_TO_FAST_CONVERTERS
)
def convert_slow_checkpoint_to_fast(tokenizer_name, checkpoint_name, dump_path, force_download):
    """
    Convert tokenizer checkpoints to the fast (`tokenizers` library) serialization format.

    Each selected checkpoint is loaded with the matching fast tokenizer class and saved under `dump_path` with
    `legacy_format=False`. Every file reported by `save_pretrained` whose name does not end with
    `tokenizer.json` is then deleted, so only the fast serialization remains.

    Args:
        tokenizer_name (`str`, *optional*):
            Key of `TOKENIZER_CLASSES` selecting the tokenizer to convert. If `None`, every entry of
            `TOKENIZER_CLASSES` is converted.
        checkpoint_name (`str`, *optional*):
            Checkpoint to load. If `None`, the keys of the tokenizer class' `max_model_input_sizes` are used.
        dump_path (`str`):
            Directory under which the converted files are written. Checkpoints containing a `/` are written
            into a sub-directory named after the part before the `/`.
        force_download (`bool`):
            Forwarded to `from_pretrained` as `force_download`.

    Raises:
        ValueError: If `tokenizer_name` is given but is not a key of `TOKENIZER_CLASSES`.
    """
    if tokenizer_name is not None and tokenizer_name not in TOKENIZER_CLASSES:
        raise ValueError(f"Unrecognized tokenizer name, should be one of {list(TOKENIZER_CLASSES.keys())}.")

    if tokenizer_name is None:
        tokenizer_names = TOKENIZER_CLASSES
    else:
        # Reuse the class registered in TOKENIZER_CLASSES instead of rebuilding `<name>Fast` by hand: the
        # mapping special-cases some names (Phi3Tokenizer -> LlamaTokenizerFast) that a plain getattr would miss.
        tokenizer_names = {tokenizer_name: TOKENIZER_CLASSES[tokenizer_name]}

    logger.info(f"Loading tokenizer classes: {tokenizer_names}")

    # Files are always saved with the checkpoint name as prefix; the flag is kept for the log messages and the
    # prefix selection below.
    add_prefix = True

    for tokenizer_class in tokenizer_names.values():
        if checkpoint_name is None:
            checkpoint_names = list(tokenizer_class.max_model_input_sizes.keys())
        else:
            checkpoint_names = [checkpoint_name]

        # `tokenizer_class` is already a class, so its own `__name__` is the tokenizer name
        # (`__class__.__name__` would be the name of its metaclass).
        logger.info(f"For tokenizer {tokenizer_class.__name__} loading checkpoints: {checkpoint_names}")

        for checkpoint in checkpoint_names:
            logger.info(f"Loading {tokenizer_class.__name__} {checkpoint}")

            # Load tokenizer
            tokenizer = tokenizer_class.from_pretrained(checkpoint, force_download=force_download)

            # Save fast tokenizer
            logger.info(f"Save fast tokenizer to {dump_path} with prefix {checkpoint} add_prefix {add_prefix}")

            # For organization names we create sub-directories
            if "/" in checkpoint:
                checkpoint_directory, checkpoint_prefix_name = checkpoint.split("/")
                dump_path_full = os.path.join(dump_path, checkpoint_directory)
            elif add_prefix:
                checkpoint_prefix_name = checkpoint
                dump_path_full = dump_path
            else:
                checkpoint_prefix_name = None
                dump_path_full = dump_path

            logger.info(f"=> {dump_path_full} with prefix {checkpoint_prefix_name}, add_prefix {add_prefix}")

            # Only the first vocabulary-file mapping is inspected; tolerate a tokenizer that declares none.
            vocab_files_maps = list(tokenizer.pretrained_vocab_files_map.values())
            first_vocab_files_map = vocab_files_maps[0] if vocab_files_maps else {}
            if checkpoint in first_vocab_files_map:
                file_path = first_vocab_files_map[checkpoint]
                # Text that follows the last occurrence of the checkpoint name in the file path. When it starts
                # with "/", the checkpoint name is a directory in that path, so the files go into a directory of
                # the same name instead of receiving a filename prefix. `startswith` also covers the case where
                # nothing follows the checkpoint name.
                path_after_checkpoint = file_path.split(checkpoint)[-1]
                if path_after_checkpoint.startswith("/"):
                    dump_path_full = os.path.join(dump_path_full, checkpoint_prefix_name)
                    checkpoint_prefix_name = None

                logger.info(f"=> {dump_path_full} with prefix {checkpoint_prefix_name}, add_prefix {add_prefix}")

            file_names = tokenizer.save_pretrained(
                dump_path_full, legacy_format=False, filename_prefix=checkpoint_prefix_name
            )
            logger.info(f"=> File names {file_names}")

            # Keep only the fast serialization: drop every other file that was just written.
            for file_name in file_names:
                if not file_name.endswith("tokenizer.json"):
                    os.remove(file_name)
                    logger.info(f"=> removing {file_name}")
if __name__ == "__main__":
    # Command-line entry point: read the conversion options and run the conversion with them.
    cli = argparse.ArgumentParser()

    # Required parameters
    cli.add_argument(
        "--dump_path",
        type=str,
        default=None,
        required=True,
        help="Path to output generated fast tokenizer files.",
    )

    # Optional parameters
    tokenizer_name_help = (
        f"Optional tokenizer type selected in the list of {list(TOKENIZER_CLASSES.keys())}. If not given, will "
        "download and convert all the checkpoints from AWS."
    )
    cli.add_argument("--tokenizer_name", type=str, default=None, help=tokenizer_name_help)
    cli.add_argument(
        "--checkpoint_name",
        type=str,
        default=None,
        help="Optional checkpoint name. If not given, will download and convert the canonical checkpoints from AWS.",
    )
    cli.add_argument("--force_download", action="store_true", help="Re-download checkpoints.")

    options = cli.parse_args()

    convert_slow_checkpoint_to_fast(
        tokenizer_name=options.tokenizer_name,
        checkpoint_name=options.checkpoint_name,
        dump_path=options.dump_path,
        force_download=options.force_download,
    )
|