Spaces:
Paused
Paused
# Copyright 2020-2025 The HuggingFace Team. All rights reserved. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
import logging | |
from typing import Callable, Literal, Optional, Union | |
from datasets import Dataset, Value | |
from transformers import AutoTokenizer | |
from ..trainer.utils import ConstantLengthDataset | |
# Feature schemas (compared against `Dataset.features`) that `get_formatting_func_from_dataset`
# recognizes and can format automatically.
FORMAT_MAPPING = {
    # Conversational column: a list of turns, each holding a "content" string and a "role" string.
    "chatml": [
        {
            "content": Value(dtype="string", id=None),
            "role": Value(dtype="string", id=None),
        },
    ],
    # Whole-dataset schema: exactly a "completion" string column and a "prompt" string column.
    "instruction": {
        "completion": Value(dtype="string", id=None),
        "prompt": Value(dtype="string", id=None),
    },
}
def conversations_formatting_function(
    tokenizer: AutoTokenizer, messages_field: Literal["messages", "conversations"], tools: Optional[list] = None
):
    r"""
    Return a callable that renders a conversational dataset with the tokenizer's chat template.

    The returned function applies `tokenizer.apply_chat_template` (with `tokenize=False`) to the
    conversation(s) stored under `messages_field`, forwarding `tools` so the template can include the
    schema of the available functions.

    Args:
        tokenizer (AutoTokenizer): Tokenizer whose chat template is used for formatting.
        messages_field (`"messages"` or `"conversations"`): Name of the column holding the conversation.
        tools (`list`, *optional*): Function/tool definitions forwarded to `apply_chat_template`.

    Returns:
        Callable: `format_dataset(examples)`. If `examples[messages_field]` is a list of conversations
        (a batch), it returns a list with one rendered text per conversation; otherwise it treats the value
        as a single conversation and returns its rendered text.
    """

    def format_dataset(examples):
        conversations = examples[messages_field]
        # A batch holds a list of conversations (each itself a list); a single example holds one
        # conversation, i.e. a list of turn dicts. The first element tells the two apart.
        if isinstance(conversations[0], list):
            return [
                tokenizer.apply_chat_template(conversation, tokenize=False, tools=tools)
                for conversation in conversations
            ]
        return tokenizer.apply_chat_template(conversations, tokenize=False, tools=tools)

    return format_dataset
def instructions_formatting_function(tokenizer: AutoTokenizer):
    r"""
    Return a callable that renders a prompt/completion dataset with the tokenizer's chat template.

    Each (prompt, completion) pair is turned into a two-turn conversation: a "user" turn holding the
    prompt, then an "assistant" turn holding the completion. That conversation is rendered with
    `tokenizer.apply_chat_template(..., tokenize=False)`.

    Args:
        tokenizer (AutoTokenizer): Tokenizer whose chat template is used for formatting.

    Returns:
        Callable: `format_dataset(examples)`. When `examples["prompt"]` is a list (a batch), it returns a
        list with one rendered text per pair; otherwise it returns the rendered text of the single pair.
    """

    def to_conversation(prompt, completion):
        # Map an instruction pair onto the role/content message structure used by chat templates.
        return [
            {"role": "user", "content": prompt},
            {"role": "assistant", "content": completion},
        ]

    def format_dataset(examples):
        prompts, completions = examples["prompt"], examples["completion"]
        if isinstance(prompts, list):
            # Batched input: the two columns are parallel lists aligned by position.
            return [
                tokenizer.apply_chat_template(to_conversation(prompt, completion), tokenize=False)
                for prompt, completion in zip(prompts, completions)
            ]
        return tokenizer.apply_chat_template(to_conversation(prompts, completions), tokenize=False)

    return format_dataset
def get_formatting_func_from_dataset(
    dataset: Union[Dataset, ConstantLengthDataset], tokenizer: AutoTokenizer, tools: Optional[list] = None
) -> Optional[Callable]:
    r"""
    Finds the correct formatting function based on the dataset structure. Currently supported datasets are:
    - `ChatML`: a `"messages"` or `"conversations"` column whose rows are [{"role": str, "content": str}, ...]
    - `instruction`: exactly two string columns, `"prompt"` and `"completion"`

    Args:
        dataset (Dataset or ConstantLengthDataset): User dataset. Only a `Dataset` is inspected; any other
            type yields `None`.
        tokenizer (AutoTokenizer): Tokenizer used for formatting.
        tools (`list`, *optional*): Function/tool definitions forwarded to the chat template. Only used for
            the ChatML format.

    Returns:
        Callable: Formatting function if the dataset format is supported else None
    """
    if not isinstance(dataset, Dataset):
        return None

    # Log through a named logger: the module-level `logging.info` helper calls `logging.basicConfig()`
    # when the root logger has no handlers, silently configuring logging for the whole application.
    logger = logging.getLogger(__name__)
    features = dataset.features

    # "messages" is checked before "conversations", so it wins when both columns match.
    for field in ("messages", "conversations"):
        if field in features and features[field] == FORMAT_MAPPING["chatml"]:
            logger.info("Formatting dataset with chatml format")
            return conversations_formatting_function(tokenizer, field, tools)

    if features == FORMAT_MAPPING["instruction"]:
        logger.info("Formatting dataset with instruction format")
        return instructions_formatting_function(tokenizer)

    return None