import copy
import json
import random

from xtuner.dataset.utils import get_bos_eos_token_ids
from xtuner.utils import DEFAULT_IMAGE_TOKEN, IGNORE_INDEX, IMAGE_TOKEN_INDEX

# Previous variant: random (rather than center) square crop, kept for reference.
# def crop2square(pil_img):
#     width, height = pil_img.width, pil_img.height
#     if width > height:
#         y0, y1 = 0, height
#         x0 = random.randint(0, width - height)  # [0, w - h]
#         x1 = x0 + height  # [h, w]
#     else:
#         x0, x1 = 0, width
#         y0 = random.randint(0, height - width)  # [0, h - w]
#         y1 = y0 + width  # [w, h]
#     return pil_img.crop(box=(x0, y0, x1, y1))

def crop2square(pil_img):
    """Center-crop a PIL image to a square with side min(width, height)."""
    width, height = pil_img.width, pil_img.height
    short = min(width, height)
    left = (width - short) // 2
    upper = (height - short) // 2
    return pil_img.crop((left, upper, left + short, upper + short))
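
# Usage sketch (illustrative; "sample.jpg" is a placeholder path, not part of
# this repo):
#   from PIL import Image
#   img = Image.open("sample.jpg")
#   square = crop2square(img)
#   assert square.width == square.height == min(img.width, img.height)
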
def load_jsonl(json_file):
    """Load a JSON-lines file: one JSON object per non-empty line."""
    data = []
    with open(json_file, encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if line:
                data.append(json.loads(line))
    return data
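
# Usage sketch (illustrative; the path and schema are assumptions, matching the
# conversation format expected by the encode functions below):
#   records = load_jsonl("data/train.jsonl")
#   first_turn = records[0]["conversation"][0]
#   print(first_turn["input"], "->", first_turn["output"])
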
def encode_fn_original(example,
tokenizer,
max_length=None,
image_length=1,
input_ids_with_output=True,
with_image_token=False,
truncation='right',
image_token_idx=None,
image_token_str="<image>"):
"""We only support the following three scenarios:
1. Incremental pretraining dataset.
example['conversation'] = [
{
'input': '',
'output': '### Human: Can you write xxx'
}
]
2. Single-turn conversation dataset.
example['conversation'] = [
{
'input': 'Give three tips for staying healthy.',
'output': '1.Eat a balanced diet xxx'
}
]
3. Multi-turn conversation dataset.
example['conversation'] = [
{
'input': 'Give three tips for staying healthy.',
'output': '1.Eat a balanced diet xxx'
},
{
'input': 'Please expand on the second point.',
'output': 'Here is an expanded explanation of the xxx'
}
]
"""
bos_token_id, eos_token_id = get_bos_eos_token_ids(tokenizer)
    if image_token_idx is None:
        # If no image token id was passed in, fall back to the tokenizer's id
        # for the configured image token string.
        image_token_idx = tokenizer.convert_tokens_to_ids(image_token_str)
is_multi_turn_conversation = len(example['conversation']) > 1
if is_multi_turn_conversation:
assert input_ids_with_output
input_ids, labels = [], []
next_needs_bos_token = True
for single_turn_conversation in example['conversation']:
input = single_turn_conversation['input']
if image_token_str in input and with_image_token:
chunk_encode = [
tokenizer.encode(chunk, add_special_tokens=False)
for chunk in input.split(image_token_str)
]
assert len(chunk_encode) == 2
input_encode = []
for idx, cur_chunk_encode in enumerate(chunk_encode):
input_encode.extend(cur_chunk_encode)
if idx != len(chunk_encode) - 1:
input_encode += [image_token_idx] * image_length
else:
input_encode = tokenizer.encode(input, add_special_tokens=False)
if next_needs_bos_token:
input_ids += bos_token_id
labels += [IGNORE_INDEX] * len(bos_token_id)
input_ids += input_encode
labels += [IGNORE_INDEX] * len(input_encode)
if input_ids_with_output and 'output' in single_turn_conversation:
# Add output
output_with_loss = single_turn_conversation.get(
'output_with_loss', True)
output = single_turn_conversation['output']
if image_token_str in output and with_image_token:
chunk_encode = [
tokenizer.encode(chunk, add_special_tokens=False)
for chunk in output.split(image_token_str)
]
assert len(chunk_encode) == 2
output_encode = []
for idx, cur_chunk_encode in enumerate(chunk_encode):
output_encode.extend(cur_chunk_encode)
if idx != len(chunk_encode) - 1:
output_encode += [image_token_idx] * image_length
else:
output_encode = tokenizer.encode(output, add_special_tokens=False)
input_ids += output_encode
if output_with_loss:
labels += copy.deepcopy(output_encode)
else:
labels += [IGNORE_INDEX] * len(output_encode)
# Add EOS_TOKEN (with loss)
if single_turn_conversation.get('need_eos_token', True):
next_needs_bos_token = True
input_ids += eos_token_id
if output_with_loss:
labels += copy.deepcopy(eos_token_id)
else:
labels += [IGNORE_INDEX] * len(eos_token_id)
else:
next_needs_bos_token = False
# Add SEP (without loss)
sep = single_turn_conversation.get('sep', '')
if sep != '':
sep_encode = tokenizer.encode(sep, add_special_tokens=False)
input_ids += sep_encode
labels += [IGNORE_INDEX] * len(sep_encode)
if max_length is not None and len(input_ids) > max_length:
if truncation == 'right':
input_ids = input_ids[:max_length]
labels = labels[:max_length]
elif truncation == 'left':
input_ids = input_ids[-max_length:]
labels = labels[-max_length:]
        else:
            assert truncation is None, \
                f"truncation must be 'left', 'right', or None, got {truncation!r}"
return {'input_ids': input_ids, 'labels': labels}
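
# Usage sketch for encode_fn_original (hedged: the checkpoint name is a
# placeholder; any HF tokenizer with an added "<image>" token behaves the same):
#   from transformers import AutoTokenizer
#   tok = AutoTokenizer.from_pretrained("internlm/internlm2-chat-7b")
#   tok.add_tokens(["<image>"], special_tokens=True)
#   sample = {"conversation": [{"input": "<image>\nDescribe the image.",
#                               "output": "A cat sleeping on a sofa."}]}
#   enc = encode_fn_original(sample, tok, max_length=2048,
#                            image_length=256, with_image_token=True)
#   # enc["input_ids"] contains 256 copies of the <image> id at the split point;
#   # enc["labels"] masks the prompt with IGNORE_INDEX and keeps loss on the output.
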
def encode_fn(
example,
tokenizer,
prompt_template=None,
max_length=None,
image_length=1,
input_ids_with_output=True,
with_image_token=True,
truncation='right',
image_token_idx=None,
image_token_str="<image>",
):
"""
A versatile encoding function for both image-to-text (conversation) and text-to-image/image-editing tasks.
- Image-to-Text: example = {"conversation": [...]}, outputs input_ids + labels.
- Text-to-Image/Editing: example = str (raw_text prompt), outputs input_ids + labels (with IGNORE_INDEX).
"""
    if image_token_idx is None:
        # Make sure the image token is in the vocabulary, then look up its id.
        tokenizer.add_tokens([image_token_str], special_tokens=True)
        image_token_idx = tokenizer.convert_tokens_to_ids(image_token_str)
    if isinstance(example, str):
        assert prompt_template is not None, \
            "prompt_template is required for text2image / image-editing examples"
        # 1) Build the prompt by wrapping the raw text with the instruction
        #    template.
        prompt = prompt_template["INSTRUCTION"].format(input=example.strip())
        # 2) Tokenize the text alone, reserving room for the image tokens. The
        #    <image> ids are concatenated manually below so the tokenizer never
        #    splits the literal token string into ordinary characters.
        text_ids = tokenizer.encode(
            prompt,
            add_special_tokens=False,
            truncation=max_length is not None,
            max_length=(max_length - image_length) if max_length else None,
        )
        # 3) Prepend image_length copies of the image token id.
        input_ids = [image_token_idx] * image_length + text_ids
        # 4) Hard-truncate if the sequence is still over budget.
        if max_length is not None and len(input_ids) > max_length:
            input_ids = input_ids[:max_length]
        attention_mask = [1] * len(input_ids)
        return {"input_ids": input_ids, "attention_mask": attention_mask}
    # --- Image-to-text task: conversation structure (single- or multi-turn) ---
    assert isinstance(example, dict) and "conversation" in example, \
        "example must be a dict with a 'conversation' key"
bos_token_id, eos_token_id = get_bos_eos_token_ids(tokenizer)
is_multi_turn = len(example["conversation"]) > 1
if is_multi_turn:
assert input_ids_with_output
input_ids, labels = [], []
next_needs_bos_token = True
for single_turn in example["conversation"]:
input_text = single_turn["input"]
# ==== Encode input ====
if with_image_token and image_token_str in input_text:
chunks = input_text.split(image_token_str)
chunk_encoded = [tokenizer.encode(c, add_special_tokens=False) for c in chunks]
assert len(chunk_encoded) >= 2
input_encode = []
for i, chunk in enumerate(chunk_encoded):
input_encode.extend(chunk)
if i < len(chunk_encoded) - 1:
input_encode.extend([image_token_idx] * image_length)
else:
input_encode = tokenizer.encode(input_text, add_special_tokens=False)
if next_needs_bos_token:
input_ids.extend(bos_token_id)
labels.extend([IGNORE_INDEX] * len(bos_token_id))
input_ids.extend(input_encode)
labels.extend([IGNORE_INDEX] * len(input_encode))
# ==== Encode output ====
if input_ids_with_output and "output" in single_turn:
output = single_turn["output"]
output_with_loss = single_turn.get("output_with_loss", True)
if with_image_token and image_token_str in output:
chunks = output.split(image_token_str)
chunk_encoded = [tokenizer.encode(c, add_special_tokens=False) for c in chunks]
assert len(chunk_encoded) >= 2
output_encode = []
for i, chunk in enumerate(chunk_encoded):
output_encode.extend(chunk)
if i < len(chunk_encoded) - 1:
output_encode.extend([image_token_idx] * image_length)
else:
output_encode = tokenizer.encode(output, add_special_tokens=False)
input_ids.extend(output_encode)
if output_with_loss:
labels.extend(output_encode.copy())
else:
labels.extend([IGNORE_INDEX] * len(output_encode))
# ==== Append EOS ====
if single_turn.get("need_eos_token", True):
next_needs_bos_token = True
input_ids.extend(eos_token_id)
if output_with_loss:
labels.extend(eos_token_id.copy())
else:
labels.extend([IGNORE_INDEX] * len(eos_token_id))
else:
next_needs_bos_token = False
# ==== Append separator ====
sep = single_turn.get("sep", "")
if sep:
sep_encoded = tokenizer.encode(sep, add_special_tokens=False)
input_ids.extend(sep_encoded)
labels.extend([IGNORE_INDEX] * len(sep_encoded))
# ==== Truncation ====
if max_length is not None and len(input_ids) > max_length:
if truncation == "right":
input_ids = input_ids[:max_length]
labels = labels[:max_length]
elif truncation == "left":
input_ids = input_ids[-max_length:]
labels = labels[-max_length:]
        elif truncation is not None:
            raise ValueError(
                f"truncation must be 'left', 'right', or None, got {truncation!r}")
    return {"input_ids": input_ids, "labels": labels}