# UniPic: src/datasets/utils.py
import copy
import json
import random

from xtuner.dataset.utils import get_bos_eos_token_ids
from xtuner.utils import DEFAULT_IMAGE_TOKEN, IGNORE_INDEX, IMAGE_TOKEN_INDEX


# Earlier variant: a random (rather than center) square crop, kept for reference.
# def crop2square(pil_img):
#     width, height = pil_img.width, pil_img.height
#     if width > height:
#         y0, y1 = 0, height
#         x0 = random.randint(0, width - height)  # [0, w - h]
#         x1 = x0 + height  # [h, w]
#     else:
#         x0, x1 = 0, width
#         y0 = random.randint(0, height - width)  # [0, h - w]
#         y1 = y0 + width  # [w, h]
#     return pil_img.crop(box=(x0, y0, x1, y1))
def crop2square(pil_img):
    """Center-crop a PIL image to a square of side min(width, height)."""
    width, height = pil_img.width, pil_img.height
    short = min(width, height)
    left = (width - short) // 2
    upper = (height - short) // 2
    return pil_img.crop((left, upper, left + short, upper + short))
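
# Usage sketch (a minimal example; the 512x512 target size is an arbitrary
# illustrative choice, not something this file prescribes):
#
#     from PIL import Image
#     img = crop2square(Image.open("photo.jpg"))
#     img = img.resize((512, 512))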


def load_jsonl(json_file):
    """Read a JSON Lines file into a list of objects, one per line."""
    with open(json_file) as f:
        return [json.loads(line) for line in f]
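
# Example (the file name and fields are illustrative only):
#
#     # data.jsonl holds one JSON object per line, e.g.
#     # {"conversation": [{"input": "<image>\nDescribe.", "output": "A cat."}]}
#     records = load_jsonl("data.jsonl")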


def encode_fn_original(example,
                       tokenizer,
                       max_length=None,
                       image_length=1,
                       input_ids_with_output=True,
                       with_image_token=False,
                       truncation='right',
                       image_token_idx=None,
                       image_token_str="<image>"):
    """We only support the following three scenarios:

    1. Incremental pretraining dataset.
        example['conversation'] = [
            {
                'input': '',
                'output': '### Human: Can you write xxx'
            }
        ]

    2. Single-turn conversation dataset.
        example['conversation'] = [
            {
                'input': 'Give three tips for staying healthy.',
                'output': '1.Eat a balanced diet xxx'
            }
        ]

    3. Multi-turn conversation dataset.
        example['conversation'] = [
            {
                'input': 'Give three tips for staying healthy.',
                'output': '1.Eat a balanced diet xxx'
            },
            {
                'input': 'Please expand on the second point.',
                'output': 'Here is an expanded explanation of the xxx'
            }
        ]
    """
    bos_token_id, eos_token_id = get_bos_eos_token_ids(tokenizer)
    if image_token_idx is None:
        # If not provided, fall back to looking the token up in the tokenizer.
        image_token_idx = tokenizer.convert_tokens_to_ids(image_token_str)
    is_multi_turn_conversation = len(example['conversation']) > 1
    if is_multi_turn_conversation:
        assert input_ids_with_output

    input_ids, labels = [], []
    next_needs_bos_token = True
    for single_turn_conversation in example['conversation']:
        input = single_turn_conversation['input']
        if image_token_str in input and with_image_token:
            chunk_encode = [
                tokenizer.encode(chunk, add_special_tokens=False)
                for chunk in input.split(image_token_str)
            ]
            assert len(chunk_encode) == 2
            input_encode = []
            for idx, cur_chunk_encode in enumerate(chunk_encode):
                input_encode.extend(cur_chunk_encode)
                if idx != len(chunk_encode) - 1:
                    input_encode += [image_token_idx] * image_length
        else:
            input_encode = tokenizer.encode(input, add_special_tokens=False)
        if next_needs_bos_token:
            input_ids += bos_token_id
            labels += [IGNORE_INDEX] * len(bos_token_id)
        input_ids += input_encode
        labels += [IGNORE_INDEX] * len(input_encode)
        if input_ids_with_output and 'output' in single_turn_conversation:
            # Add output
            output_with_loss = single_turn_conversation.get(
                'output_with_loss', True)
            output = single_turn_conversation['output']
            if image_token_str in output and with_image_token:
                chunk_encode = [
                    tokenizer.encode(chunk, add_special_tokens=False)
                    for chunk in output.split(image_token_str)
                ]
                assert len(chunk_encode) == 2
                output_encode = []
                for idx, cur_chunk_encode in enumerate(chunk_encode):
                    output_encode.extend(cur_chunk_encode)
                    if idx != len(chunk_encode) - 1:
                        output_encode += [image_token_idx] * image_length
            else:
                output_encode = tokenizer.encode(output, add_special_tokens=False)
            input_ids += output_encode
            if output_with_loss:
                labels += copy.deepcopy(output_encode)
            else:
                labels += [IGNORE_INDEX] * len(output_encode)
            # Add EOS_TOKEN (with loss)
            if single_turn_conversation.get('need_eos_token', True):
                next_needs_bos_token = True
                input_ids += eos_token_id
                if output_with_loss:
                    labels += copy.deepcopy(eos_token_id)
                else:
                    labels += [IGNORE_INDEX] * len(eos_token_id)
            else:
                next_needs_bos_token = False
            # Add SEP (without loss)
            sep = single_turn_conversation.get('sep', '')
            if sep != '':
                sep_encode = tokenizer.encode(sep, add_special_tokens=False)
                input_ids += sep_encode
                labels += [IGNORE_INDEX] * len(sep_encode)

    if max_length is not None and len(input_ids) > max_length:
        if truncation == 'right':
            input_ids = input_ids[:max_length]
            labels = labels[:max_length]
        elif truncation == 'left':
            input_ids = input_ids[-max_length:]
            labels = labels[-max_length:]
        else:
            assert truncation is None
    return {'input_ids': input_ids, 'labels': labels}
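
# Usage sketch (assumes a HuggingFace-style tokenizer, e.g. one returned by
# transformers.AutoTokenizer.from_pretrained(...); the numbers are illustrative):
#
#     example = {"conversation": [
#         {"input": "<image>\nDescribe this image.", "output": "A cat on a mat."},
#     ]}
#     enc = encode_fn_original(example, tokenizer, max_length=2048,
#                              image_length=576, with_image_token=True)
#     # enc["input_ids"] and enc["labels"] are parallel lists; prompt tokens
#     # carry IGNORE_INDEX in labels so loss is computed only on the response.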


def encode_fn(
        example,
        tokenizer,
        prompt_template=None,
        max_length=None,
        image_length=1,
        input_ids_with_output=True,
        with_image_token=True,
        truncation='right',
        image_token_idx=None,
        image_token_str="<image>",
):
    """A versatile encoding function for both image-to-text (conversation)
    and text-to-image / image-editing tasks.

    - Image-to-text: example = {"conversation": [...]}; returns input_ids + labels.
    - Text-to-image / editing: example = str (a raw text prompt); returns
      input_ids + attention_mask.
    """
    if image_token_idx is None:
        # Register the image token if needed, then look up its id. (If the
        # token is newly added, the model's embedding matrix must be resized
        # accordingly.)
        tokenizer.add_tokens([image_token_str], special_tokens=True)
        image_token_idx = tokenizer.convert_tokens_to_ids(image_token_str)
    if isinstance(example, str):
        assert prompt_template is not None, \
            "prompt_template must not be None for text2image / image-editing"
        # 1) Build the prompt: wrap the raw text with the instruction template.
        prompt = f"{example.strip()}"
        prompt = prompt_template["INSTRUCTION"].format(input=prompt)
        # 2) Encode with the tokenizer. To keep the tokenizer from splitting
        #    <image> as ordinary text, encode the plain prompt here and splice
        #    the image token ids in manually below.
        text_ids = tokenizer.encode(
            prompt,
            add_special_tokens=False,
            truncation=True,
            max_length=(max_length - image_length) if max_length else None
        )
        # 3) Prepend the <image> token ids (or place them wherever needed).
        input_ids = [image_token_idx] * image_length + text_ids
        # 4) Truncate if still too long.
        if max_length is not None and len(input_ids) > max_length:
            input_ids = input_ids[:max_length]
        # 5) Attention mask: every position is a real token.
        attention_mask = [1] * len(input_ids)
        return {"input_ids": input_ids, "attention_mask": attention_mask}
    # --- Image-to-text task: multi-turn conversation structure ---
    assert isinstance(example, dict) and "conversation" in example
    bos_token_id, eos_token_id = get_bos_eos_token_ids(tokenizer)
    is_multi_turn = len(example["conversation"]) > 1
    if is_multi_turn:
        assert input_ids_with_output

    input_ids, labels = [], []
    next_needs_bos_token = True
    for single_turn in example["conversation"]:
        input_text = single_turn["input"]
        # ==== Encode input ====
        if with_image_token and image_token_str in input_text:
            chunks = input_text.split(image_token_str)
            chunk_encoded = [tokenizer.encode(c, add_special_tokens=False) for c in chunks]
            assert len(chunk_encoded) >= 2
            input_encode = []
            for i, chunk in enumerate(chunk_encoded):
                input_encode.extend(chunk)
                if i < len(chunk_encoded) - 1:
                    input_encode.extend([image_token_idx] * image_length)
        else:
            input_encode = tokenizer.encode(input_text, add_special_tokens=False)
        if next_needs_bos_token:
            input_ids.extend(bos_token_id)
            labels.extend([IGNORE_INDEX] * len(bos_token_id))
        input_ids.extend(input_encode)
        labels.extend([IGNORE_INDEX] * len(input_encode))
        # ==== Encode output ====
        if input_ids_with_output and "output" in single_turn:
            output = single_turn["output"]
            output_with_loss = single_turn.get("output_with_loss", True)
            if with_image_token and image_token_str in output:
                chunks = output.split(image_token_str)
                chunk_encoded = [tokenizer.encode(c, add_special_tokens=False) for c in chunks]
                assert len(chunk_encoded) >= 2
                output_encode = []
                for i, chunk in enumerate(chunk_encoded):
                    output_encode.extend(chunk)
                    if i < len(chunk_encoded) - 1:
                        output_encode.extend([image_token_idx] * image_length)
            else:
                output_encode = tokenizer.encode(output, add_special_tokens=False)
            input_ids.extend(output_encode)
            if output_with_loss:
                labels.extend(output_encode.copy())
            else:
                labels.extend([IGNORE_INDEX] * len(output_encode))
            # ==== Append EOS (nested here so output_with_loss is defined) ====
            if single_turn.get("need_eos_token", True):
                next_needs_bos_token = True
                input_ids.extend(eos_token_id)
                if output_with_loss:
                    labels.extend(eos_token_id.copy())
                else:
                    labels.extend([IGNORE_INDEX] * len(eos_token_id))
            else:
                next_needs_bos_token = False
            # ==== Append separator (without loss) ====
            sep = single_turn.get("sep", "")
            if sep:
                sep_encoded = tokenizer.encode(sep, add_special_tokens=False)
                input_ids.extend(sep_encoded)
                labels.extend([IGNORE_INDEX] * len(sep_encoded))
    # ==== Truncation ====
    if max_length is not None and len(input_ids) > max_length:
        if truncation == "right":
            input_ids = input_ids[:max_length]
            labels = labels[:max_length]
        elif truncation == "left":
            input_ids = input_ids[-max_length:]
            labels = labels[-max_length:]
        elif truncation is None:
            pass  # explicitly requested: keep the full sequence
        else:
            raise ValueError("truncation must be 'left', 'right', or None")
    return {"input_ids": input_ids, "labels": labels}