import copy
import json

from xtuner.dataset.utils import get_bos_eos_token_ids
from xtuner.utils import DEFAULT_IMAGE_TOKEN, IGNORE_INDEX, IMAGE_TOKEN_INDEX


def crop2square(pil_img):
    """Center-crop a PIL image to a square whose side is min(width, height)."""
    width, height = pil_img.width, pil_img.height
    short = min(width, height)
    left = (width - short) // 2
    upper = (height - short) // 2
    return pil_img.crop((left, upper, left + short, upper + short))

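# A quick usage sketch for crop2square (illustrative only; Pillow is the
# assumed image library, since the function operates on PIL images):
#
#     from PIL import Image
#     img = Image.new("RGB", (640, 480))
#     assert crop2square(img).size == (480, 480)  # center 480x480 crop
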
def load_jsonl(json_file):
    """Load a JSONL file: one JSON object per line, returned as a list."""
    with open(json_file) as f:
        return [json.loads(line) for line in f]

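# Example: given a data.jsonl whose lines are standalone JSON objects such as
# {"conversation": [{"input": "Hi", "output": "Hello."}]}, calling
# load_jsonl("data.jsonl") returns them as a list of dicts. The file name is
# illustrative.

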
def encode_fn_original(example,
                       tokenizer,
                       max_length=None,
                       image_length=1,
                       input_ids_with_output=True,
                       with_image_token=False,
                       truncation='right',
                       image_token_idx=None,
                       image_token_str='<image>'):
    """We only support the following three scenarios:

    1. Incremental pretraining dataset.
        example['conversation'] = [
            {
                'input': '',
                'output': '### Human: Can you write xxx'
            }
        ]

    2. Single-turn conversation dataset.
        example['conversation'] = [
            {
                'input': 'Give three tips for staying healthy.',
                'output': '1. Eat a balanced diet xxx'
            }
        ]

    3. Multi-turn conversation dataset.
        example['conversation'] = [
            {
                'input': 'Give three tips for staying healthy.',
                'output': '1. Eat a balanced diet xxx'
            },
            {
                'input': 'Please expand on the second point.',
                'output': 'Here is an expanded explanation of the xxx'
            }
        ]
    """
    bos_token_id, eos_token_id = get_bos_eos_token_ids(tokenizer)
    if image_token_idx is None:
        # Resolve the placeholder id from image_token_str rather than a
        # hard-coded '<image>', so a custom placeholder string works too.
        image_token_idx = tokenizer.convert_tokens_to_ids(image_token_str)

    is_multi_turn_conversation = len(example['conversation']) > 1
    if is_multi_turn_conversation:
        assert input_ids_with_output

    input_ids, labels = [], []
    next_needs_bos_token = True
    for single_turn_conversation in example['conversation']:
        input = single_turn_conversation['input']
        if image_token_str in input and with_image_token:
            # Split around the placeholder and splice in `image_length`
            # copies of the image token id.
            chunk_encode = [
                tokenizer.encode(chunk, add_special_tokens=False)
                for chunk in input.split(image_token_str)
            ]
            assert len(chunk_encode) == 2  # exactly one placeholder allowed
            input_encode = []
            for idx, cur_chunk_encode in enumerate(chunk_encode):
                input_encode.extend(cur_chunk_encode)
                if idx != len(chunk_encode) - 1:
                    input_encode += [image_token_idx] * image_length
        else:
            input_encode = tokenizer.encode(input, add_special_tokens=False)
        if next_needs_bos_token:
            input_ids += bos_token_id
            labels += [IGNORE_INDEX] * len(bos_token_id)
        input_ids += input_encode
        labels += [IGNORE_INDEX] * len(input_encode)  # no loss on the prompt
        if input_ids_with_output and 'output' in single_turn_conversation:
            output_with_loss = single_turn_conversation.get(
                'output_with_loss', True)
            output = single_turn_conversation['output']
            if image_token_str in output and with_image_token:
                chunk_encode = [
                    tokenizer.encode(chunk, add_special_tokens=False)
                    for chunk in output.split(image_token_str)
                ]
                assert len(chunk_encode) == 2
                output_encode = []
                for idx, cur_chunk_encode in enumerate(chunk_encode):
                    output_encode.extend(cur_chunk_encode)
                    if idx != len(chunk_encode) - 1:
                        output_encode += [image_token_idx] * image_length
            else:
                output_encode = tokenizer.encode(
                    output, add_special_tokens=False)

            input_ids += output_encode
            if output_with_loss:
                labels += copy.deepcopy(output_encode)
            else:
                labels += [IGNORE_INDEX] * len(output_encode)

            if single_turn_conversation.get('need_eos_token', True):
                next_needs_bos_token = True
                input_ids += eos_token_id
                if output_with_loss:
                    labels += copy.deepcopy(eos_token_id)
                else:
                    labels += [IGNORE_INDEX] * len(eos_token_id)
            else:
                next_needs_bos_token = False

            sep = single_turn_conversation.get('sep', '')
            if sep != '':
                sep_encode = tokenizer.encode(sep, add_special_tokens=False)
                input_ids += sep_encode
                labels += [IGNORE_INDEX] * len(sep_encode)

    if max_length is not None and len(input_ids) > max_length:
        if truncation == 'right':
            input_ids = input_ids[:max_length]
            labels = labels[:max_length]
        elif truncation == 'left':
            input_ids = input_ids[-max_length:]
            labels = labels[-max_length:]
        else:
            assert truncation is None
    return {'input_ids': input_ids, 'labels': labels}

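# For intuition, a single-turn example runs through the function above like
# this (token ids below are schematic, not real ids):
#
#     example = {'conversation': [{'input': 'Hi', 'output': 'Hello!'}]}
#     out = encode_fn_original(example, tokenizer)
#     # out['input_ids'] -> [BOS,  Hi...,    Hello..., EOS]
#     # out['labels']    -> [-100, -100...,  Hello..., EOS]
#
# i.e. prompt tokens are masked with IGNORE_INDEX (-100), and only the answer
# and its EOS token contribute to the loss.

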
def encode_fn(
    example,
    tokenizer,
    prompt_template=None,
    max_length=None,
    image_length=1,
    input_ids_with_output=True,
    with_image_token=True,
    truncation='right',
    image_token_idx=None,
    image_token_str="<image>",
):
    """A versatile encoding function for both image-to-text (conversation)
    and text-to-image/image-editing tasks.

    - Image-to-text: example = {"conversation": [...]}; returns input_ids
      and labels (with IGNORE_INDEX masking the prompt).
    - Text-to-image/editing: example = str (the raw text prompt); returns
      input_ids and attention_mask (this path produces no labels).
    """
    if image_token_idx is None:
        # Register the placeholder as a special token so it is tokenized as
        # a single id, then look that id up.
        tokenizer.add_tokens([image_token_str], special_tokens=True)
        image_token_idx = tokenizer.convert_tokens_to_ids(image_token_str)

    # Text-to-image / image-editing: the example is a raw prompt string.
    if isinstance(example, str):
        assert prompt_template is not None, \
            "prompt_template must not be None for text2image/image-editing"

        prompt = prompt_template["INSTRUCTION"].format(input=example.strip())

        # Reserve room for the image tokens prepended below.
        text_ids = tokenizer.encode(
            prompt,
            add_special_tokens=False,
            truncation=True,
            max_length=(max_length - image_length) if max_length else None)

        input_ids = [image_token_idx] * image_length + text_ids

        if max_length is not None and len(input_ids) > max_length:
            input_ids = input_ids[:max_length]

        attention_mask = [1] * len(input_ids)
        return {"input_ids": input_ids, "attention_mask": attention_mask}

    # Image-to-text: same scheme as encode_fn_original, but allowing more
    # than one image placeholder per message.
    assert isinstance(example, dict) and "conversation" in example
    bos_token_id, eos_token_id = get_bos_eos_token_ids(tokenizer)
    is_multi_turn = len(example["conversation"]) > 1
    if is_multi_turn:
        assert input_ids_with_output

    input_ids, labels = [], []
    next_needs_bos_token = True

    for single_turn in example["conversation"]:
        input_text = single_turn["input"]

        if with_image_token and image_token_str in input_text:
            chunks = input_text.split(image_token_str)
            chunk_encoded = [
                tokenizer.encode(c, add_special_tokens=False) for c in chunks
            ]
            assert len(chunk_encoded) >= 2
            input_encode = []
            for i, chunk in enumerate(chunk_encoded):
                input_encode.extend(chunk)
                if i < len(chunk_encoded) - 1:
                    input_encode.extend([image_token_idx] * image_length)
        else:
            input_encode = tokenizer.encode(
                input_text, add_special_tokens=False)

        if next_needs_bos_token:
            input_ids.extend(bos_token_id)
            labels.extend([IGNORE_INDEX] * len(bos_token_id))

        input_ids.extend(input_encode)
        labels.extend([IGNORE_INDEX] * len(input_encode))

        if input_ids_with_output and "output" in single_turn:
            output = single_turn["output"]
            output_with_loss = single_turn.get("output_with_loss", True)

            if with_image_token and image_token_str in output:
                chunks = output.split(image_token_str)
                chunk_encoded = [
                    tokenizer.encode(c, add_special_tokens=False)
                    for c in chunks
                ]
                assert len(chunk_encoded) >= 2
                output_encode = []
                for i, chunk in enumerate(chunk_encoded):
                    output_encode.extend(chunk)
                    if i < len(chunk_encoded) - 1:
                        output_encode.extend(
                            [image_token_idx] * image_length)
            else:
                output_encode = tokenizer.encode(
                    output, add_special_tokens=False)

            input_ids.extend(output_encode)
            if output_with_loss:
                labels.extend(output_encode.copy())
            else:
                labels.extend([IGNORE_INDEX] * len(output_encode))

            if single_turn.get("need_eos_token", True):
                next_needs_bos_token = True
                input_ids.extend(eos_token_id)
                if output_with_loss:
                    labels.extend(eos_token_id.copy())
                else:
                    labels.extend([IGNORE_INDEX] * len(eos_token_id))
            else:
                next_needs_bos_token = False

            sep = single_turn.get("sep", "")
            if sep:
                sep_encoded = tokenizer.encode(sep, add_special_tokens=False)
                input_ids.extend(sep_encoded)
                labels.extend([IGNORE_INDEX] * len(sep_encoded))

    if max_length is not None and len(input_ids) > max_length:
        if truncation == "right":
            input_ids = input_ids[:max_length]
            labels = labels[:max_length]
        elif truncation == "left":
            input_ids = input_ids[-max_length:]
            labels = labels[-max_length:]
        elif truncation is not None:
            # truncation=None leaves over-length sequences untouched.
            raise ValueError("truncation must be 'left', 'right', or None")

    return {"input_ids": input_ids, "labels": labels}
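

if __name__ == "__main__":
    # Minimal smoke test for both encode_fn paths. This sketch assumes a
    # HuggingFace tokenizer is available; "gpt2" and the template below are
    # illustrative choices, not dependencies of this module.
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    template = {"INSTRUCTION": "USER: {input} ASSISTANT:"}  # hypothetical

    # Text-to-image path: image tokens are prepended to the templated prompt.
    t2i = encode_fn("a cat on a mat", tokenizer,
                    prompt_template=template, max_length=64, image_length=4)
    image_idx = tokenizer.convert_tokens_to_ids("<image>")
    assert t2i["input_ids"][:4] == [image_idx] * 4

    # Image-to-text path: only the answer (plus EOS) carries loss.
    conv = {"conversation": [{"input": "Describe <image> briefly.",
                              "output": "A cat."}]}
    i2t = encode_fn(conv, tokenizer, max_length=64, image_length=4)
    n_supervised = sum(label != IGNORE_INDEX for label in i2t["labels"])
    print("supervised tokens:", n_supervised)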