import copy
import json
import random

from xtuner.dataset.utils import get_bos_eos_token_ids
from xtuner.utils import DEFAULT_IMAGE_TOKEN, IGNORE_INDEX, IMAGE_TOKEN_INDEX


def crop2square(pil_img):
    """Center-crop an image to a square whose side is its shorter edge.

    Args:
        pil_img: An object exposing ``width``, ``height`` and
            ``crop(box)`` (e.g. a ``PIL.Image.Image``).

    Returns:
        Whatever ``pil_img.crop`` returns for the centered square box
        ``(left, upper, right, lower)``.
    """
    width, height = pil_img.width, pil_img.height
    side = min(width, height)
    left = (width - side) // 2
    upper = (height - side) // 2
    return pil_img.crop((left, upper, left + side, upper + side))


def load_jsonl(json_file):
    """Load a JSON Lines file into a list of parsed records.

    The file is read as UTF-8 (the encoding JSON mandates) instead of the
    platform default. Blank or whitespace-only lines, such as a trailing
    empty line, are skipped rather than crashing ``json.loads``.

    Args:
        json_file: Path of the ``.jsonl`` file.

    Returns:
        A list with one parsed JSON value per non-blank line, in file order.
    """
    with open(json_file, encoding='utf-8') as f:
        return [json.loads(line) for line in f if line.strip()]


def _encode_with_image_placeholders(text, tokenizer, *, with_image_token,
                                    image_token_str, image_token_idx,
                                    image_length, single_image):
    """Tokenize ``text``, expanding each placeholder into image token ids.

    Every occurrence of ``image_token_str`` is replaced by ``image_length``
    copies of ``image_token_idx``. The surrounding text chunks are encoded
    without special tokens.

    Args:
        text: The string to encode.
        tokenizer: Object with ``encode(text, add_special_tokens=False)``.
        with_image_token: If False, the text is encoded verbatim.
        image_token_str: Placeholder substring. An empty string disables
            expansion.
        image_token_idx: Id inserted in place of each placeholder.
        image_length: Number of ids inserted per placeholder.
        single_image: If True, assert exactly one placeholder is present.

    Returns:
        A list of token ids.
    """
    # An empty placeholder is a substring of every string and would make
    # ``str.split`` raise ``ValueError: empty separator``, so treat it as
    # "no placeholder" rather than crashing.
    if not (with_image_token and image_token_str
            and image_token_str in text):
        return tokenizer.encode(text, add_special_tokens=False)

    chunks = text.split(image_token_str)
    if single_image:
        assert len(chunks) == 2, 'expected exactly one image placeholder'

    encoded = []
    for i, chunk in enumerate(chunks):
        if i > 0:
            # A placeholder sat between chunk i-1 and chunk i.
            encoded.extend([image_token_idx] * image_length)
        encoded.extend(tokenizer.encode(chunk, add_special_tokens=False))
    return encoded


def _truncate_pair(input_ids, labels, max_length, truncation):
    """Clip ``input_ids`` and ``labels`` to ``max_length`` on one side.

    Args:
        input_ids: Token ids.
        labels: Labels aligned one-to-one with ``input_ids``.
        max_length: Maximum length, or None for no limit.
        truncation: ``'right'`` keeps the head, ``'left'`` keeps the tail,
            and None leaves over-long sequences untouched.

    Returns:
        The ``(input_ids, labels)`` pair, possibly shortened.

    Raises:
        ValueError: If the sequence is over-long and ``truncation`` is not
            one of ``'left'``, ``'right'`` or None.
    """
    if max_length is None or len(input_ids) <= max_length:
        return input_ids, labels
    if truncation == 'right':
        return input_ids[:max_length], labels[:max_length]
    if truncation == 'left':
        # ``seq[-0:]`` would return the whole list when max_length == 0,
        # so compute an explicit start index instead.
        start = len(input_ids) - max_length
        return input_ids[start:], labels[start:]
    if truncation is None:
        return input_ids, labels
    raise ValueError("truncation must be 'left', 'right', or None")


def _encode_conversation(example, tokenizer, *, max_length, image_length,
                         input_ids_with_output, with_image_token,
                         truncation, image_token_idx, image_token_str,
                         single_image):
    """Build ``input_ids``/``labels`` for a ``{'conversation': [...]}`` sample.

    Per turn, the following keys are read:

    * ``'input'`` (required): encoded and fully masked with IGNORE_INDEX.
    * ``'output'`` (optional): appended only if ``input_ids_with_output``.
    * ``'output_with_loss'`` (default True): whether the output and EOS
      count toward the loss.
    * ``'need_eos_token'`` (default True): whether EOS follows the output.
      When False, the next turn is not prefixed with BOS.
    * ``'sep'`` (default ``''``): separator appended after the output,
      always masked.

    Returns:
        ``{'input_ids': [...], 'labels': [...]}``.
    """
    # Assumed to be sequences of ids: they are concatenated onto lists and
    # their ``len`` is taken below.
    bos_token_id, eos_token_id = get_bos_eos_token_ids(tokenizer)

    conversation = example['conversation']
    if len(conversation) > 1:
        # Multi-turn samples require the answers to be encoded as well.
        assert input_ids_with_output

    def encode(text):
        return _encode_with_image_placeholders(
            text,
            tokenizer,
            with_image_token=with_image_token,
            image_token_str=image_token_str,
            image_token_idx=image_token_idx,
            image_length=image_length,
            single_image=single_image,
        )

    input_ids, labels = [], []
    next_needs_bos_token = True
    for turn in conversation:
        input_encode = encode(turn['input'])
        if next_needs_bos_token:
            input_ids += bos_token_id
            labels += [IGNORE_INDEX] * len(bos_token_id)
        # The prompt side never contributes to the loss.
        input_ids += input_encode
        labels += [IGNORE_INDEX] * len(input_encode)

        if not (input_ids_with_output and 'output' in turn):
            continue

        output_with_loss = turn.get('output_with_loss', True)
        output_encode = encode(turn['output'])
        input_ids += output_encode
        if output_with_loss:
            labels += output_encode
        else:
            labels += [IGNORE_INDEX] * len(output_encode)

        # EOS shares the output's loss setting.
        if turn.get('need_eos_token', True):
            next_needs_bos_token = True
            input_ids += eos_token_id
            if output_with_loss:
                labels += eos_token_id
            else:
                labels += [IGNORE_INDEX] * len(eos_token_id)
        else:
            next_needs_bos_token = False

        # The separator is always excluded from the loss.
        sep = turn.get('sep', '')
        if sep:
            sep_encode = tokenizer.encode(sep, add_special_tokens=False)
            input_ids += sep_encode
            labels += [IGNORE_INDEX] * len(sep_encode)

    input_ids, labels = _truncate_pair(input_ids, labels, max_length,
                                       truncation)
    return {'input_ids': input_ids, 'labels': labels}


def _encode_t2i_prompt(prompt_text, tokenizer, prompt_template, max_length,
                       image_length, image_token_idx):
    """Encode a text-to-image/editing prompt as [image ids] + template text."""
    assert prompt_template is not None, \
        "prompt_template 不能为空(text2image/image-editing)"

    prompt = prompt_template["INSTRUCTION"].format(input=prompt_text.strip())
    text_ids = tokenizer.encode(prompt, add_special_tokens=False)
    if max_length is not None:
        # Reserve room for the prepended image ids. Truncate here instead of
        # via the tokenizer so that max_length=None never triggers the
        # tokenizer's own model_max_length clipping.
        text_ids = text_ids[:max(max_length - image_length, 0)]

    # The image ids are inserted as ids, not as placeholder text.
    input_ids = [image_token_idx] * image_length + text_ids
    if max_length is not None:
        input_ids = input_ids[:max_length]
    return {"input_ids": input_ids, "attention_mask": [1] * len(input_ids)}


def encode_fn_original(example, tokenizer, max_length=None, image_length=1,
                       input_ids_with_output=True, with_image_token=False,
                       truncation='right', image_token_idx=None,
                       image_token_str=""):
    """Encode a conversation sample into ``input_ids`` and ``labels``.

    Only the following three scenarios are supported:

    1. Incremental pretraining dataset.
        example['conversation'] = [
                {
                    'input': '',
                    'output': '### Human: Can you write xxx'
                }
            ]

    2. Single-turn conversation dataset.
        example['conversation'] = [
                {
                    'input': 'Give three tips for staying healthy.',
                    'output': '1.Eat a balanced diet xxx'
                }
            ]

    3. Multi-turn conversation dataset.
        example['conversation'] = [
                {
                    'input': 'Give three tips for staying healthy.',
                    'output': '1.Eat a balanced diet xxx'
                },
                {
                    'input': 'Please expand on the second point.',
                    'output': 'Here is an expanded explanation of the xxx'
                }
            ]

    When ``with_image_token`` is True and ``image_token_str`` is non-empty,
    each ``input``/``output`` may contain at most one ``image_token_str``.
    That placeholder is expanded to ``image_length`` copies of
    ``image_token_idx``.

    Args:
        example: Dict with a ``'conversation'`` list as shown above.
        tokenizer: Tokenizer providing ``encode`` and
            ``convert_tokens_to_ids``.
        max_length: Optional maximum sequence length.
        image_length: Ids emitted per image placeholder.
        input_ids_with_output: Whether to append the outputs. Must be True
            for multi-turn samples.
        with_image_token: Whether to expand image placeholders.
        truncation: ``'right'``, ``'left'`` or None (no truncation).
        image_token_idx: Id for the image placeholder. If None, it is looked
            up from ``image_token_str`` via the tokenizer.
        image_token_str: The placeholder text. NOTE(review): the default is
            empty, which disables expansion. Callers presumably pass the real
            placeholder; confirm.

    Returns:
        ``{'input_ids': [...], 'labels': [...]}``.

    Raises:
        AssertionError: If a multi-turn sample is encoded without outputs,
            or if a text contains more than one placeholder.
        ValueError: If an over-long sample uses an unknown ``truncation``.
    """
    if image_token_idx is None:
        # Look up the id of the placeholder actually in use.
        image_token_idx = tokenizer.convert_tokens_to_ids(image_token_str)

    return _encode_conversation(
        example,
        tokenizer,
        max_length=max_length,
        image_length=image_length,
        input_ids_with_output=input_ids_with_output,
        with_image_token=with_image_token,
        truncation=truncation,
        image_token_idx=image_token_idx,
        image_token_str=image_token_str,
        single_image=True,
    )


def encode_fn(
    example,
    tokenizer,
    prompt_template=None,
    max_length=None,
    image_length=1,
    input_ids_with_output=True,
    with_image_token=True,
    truncation='right',
    image_token_idx=None,
    image_token_str="",
):
    """Encode an image-to-text conversation or a text-to-image prompt.

    * Text-to-image / image-editing: ``example`` is a ``str``.
      - The stripped text is formatted into ``prompt_template["INSTRUCTION"]``
        and tokenized.
      - ``image_length`` copies of ``image_token_idx`` are prepended as ids.
      - Returns ``{"input_ids", "attention_mask"}``; no labels are produced.
    * Image-to-text: ``example`` is a dict with a ``"conversation"`` list in
      the format accepted by ``encode_fn_original``.
      - Any number of image placeholders per text is allowed.
      - Returns ``{"input_ids", "labels"}``.

    Args:
        example: A prompt string or a conversation dict.
        tokenizer: Tokenizer providing ``encode``, ``add_tokens`` and
            ``convert_tokens_to_ids``.
        prompt_template: Dict with an ``"INSTRUCTION"`` format string
            containing ``{input}``. Required for string examples.
        max_length: Optional maximum sequence length.
        image_length: Ids emitted per image.
        input_ids_with_output: Conversation path only; whether to append
            the outputs.
        with_image_token: Conversation path only; whether to expand
            placeholders.
        truncation: Conversation path only; ``'right'``, ``'left'`` or None.
            The prompt path always keeps the image ids and the head of the
            text.
        image_token_idx: Id for the image token. If None, it is registered
            on / looked up from the tokenizer.
        image_token_str: The placeholder text. An empty string disables
            placeholder expansion.

    Returns:
        See above.

    Raises:
        AssertionError: On a missing ``prompt_template`` for string input,
            a malformed ``example``, or a multi-turn sample encoded without
            outputs.
        ValueError: If an over-long conversation uses an unknown
            ``truncation``.
    """
    if image_token_idx is None:
        # NOTE(review): this mutates the shared tokenizer, growing the vocab
        # if the token is new. Model embeddings presumably need resizing
        # elsewhere; confirm.
        if image_token_str:
            tokenizer.add_tokens([image_token_str], special_tokens=True)
        image_token_idx = tokenizer.convert_tokens_to_ids(image_token_str)

    if isinstance(example, str):
        return _encode_t2i_prompt(example, tokenizer, prompt_template,
                                  max_length, image_length, image_token_idx)

    # --- Image-to-text task: (multi-turn) conversation structure ---
    assert isinstance(example, dict) and "conversation" in example
    return _encode_conversation(
        example,
        tokenizer,
        max_length=max_length,
        image_length=image_length,
        input_ids_with_output=input_ids_with_output,
        with_image_token=with_image_token,
        truncation=truncation,
        image_token_idx=image_token_idx,
        image_token_str=image_token_str,
        single_image=False,
    )