from functools import partial
from typing import Dict, Sequence

import torch
from torch.nn.utils.rnn import pad_sequence

from xtuner.utils import IGNORE_INDEX


def collate_func_gen(instances: Sequence[Dict],
                     pad_index: int = 151645):
    """Collate function for generation samples.

    Right-pads ``input_ids`` to the longest sequence in the batch and builds
    the matching attention mask. The default ``pad_index`` (151645) is
    Qwen2's ``<|im_end|>`` token id.
    """
    pixel_values_src, pixel_values, input_ids, input_lengths = [], [], [], []

    for example in instances:
        if 'pixel_values_src' in example:
            pixel_values_src.append(example.pop('pixel_values_src'))
        if 'pixel_values' in example:
            pixel_values.append(example.pop('pixel_values'))

        input_lengths.append(len(example['input_ids']))
        input_ids.append(example.pop('input_ids'))

    input_ids = pad_sequence(
        input_ids, batch_first=True, padding_value=pad_index)

    # Valid (unpadded) positions are True; padding stays False.
    attention_mask = torch.zeros_like(input_ids).bool()
    for i in range(len(input_ids)):
        attention_mask[i, :input_lengths[i]] = True

    data_dict = {
        'input_ids': input_ids,
        'attention_mask': attention_mask,
    }

    # Image tensors are optional; stack them only when present.
    if pixel_values:
        data_dict['pixel_values'] = torch.stack(pixel_values)
    if pixel_values_src:
        data_dict['pixel_values_src'] = torch.stack(pixel_values_src)

    return {'data': data_dict, 'data_samples': None}
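
# Usage sketch (illustrative only, not part of the module API). Note that
# `collate_func_gen` passes `input_ids` straight to `pad_sequence`, so they
# must already be tensors (unlike `collate_func_und`, which converts lists):
#
#   batch = collate_func_gen([
#       {'input_ids': torch.tensor([1, 2, 3])},
#       {'input_ids': torch.tensor([1, 2])},
#   ])
#   batch['data']['input_ids'].shape      # torch.Size([2, 3])
#   batch['data']['attention_mask'][1]    # tensor([True, True, False])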


def collate_func_und(instances: Sequence[Dict],
                     pad_index: int = 151645):
    """Collate function for understanding samples.

    Pads ``input_ids`` with ``pad_index`` and ``labels`` with
    ``IGNORE_INDEX``, so padded positions are excluded from the loss.
    """
    input_ids_list, labels_list, pixel_values_list = [], [], []

    for sample in instances:
        input_ids_list.append(torch.LongTensor(sample['input_ids']))
        labels_list.append(torch.LongTensor(sample['labels']))

        if 'pixel_values' in sample:
            pixel_values_list.append(sample['pixel_values'])

    # Unpadded lengths, needed to build the attention mask after padding.
    ori_length = [len(input_ids_) for input_ids_ in input_ids_list]

    if len(instances) > 1:
        input_ids = pad_sequence(
            input_ids_list, batch_first=True, padding_value=pad_index)
        labels = pad_sequence(
            labels_list, batch_first=True, padding_value=IGNORE_INDEX)
    else:
        input_ids = torch.stack(input_ids_list)
        labels = torch.stack(labels_list)

    attention_mask = torch.zeros_like(input_ids).bool()
    for i, length in enumerate(ori_length):
        attention_mask[i, :length] = True

    data_dict = {
        'input_ids': input_ids,
        'attention_mask': attention_mask,
        'labels': labels,
        'pixel_values': torch.stack(pixel_values_list)
        if len(pixel_values_list) > 0 else None
    }

    return {'data': data_dict, 'data_samples': None}
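
# Usage sketch (illustrative only): padded label positions receive xtuner's
# IGNORE_INDEX (-100) so they are dropped from the loss:
#
#   batch = collate_func_und([
#       {'input_ids': [1, 2, 3], 'labels': [-100, 2, 3]},
#       {'input_ids': [1, 2], 'labels': [-100, 2]},
#   ])
#   batch['data']['labels'][1]    # tensor([-100, 2, -100])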


class CollateConcat:
    """Route a batch to one of several collate functions.

    ``collate_fns`` holds config dicts whose ``type`` entry is the collate
    callable; the remaining entries are bound as keyword arguments. Each
    sample must carry a ``type`` key naming its collate function, and a
    batch is assumed to be homogeneous in ``type``.
    """

    def __init__(self, collate_fns, keys):
        self.keys = keys
        self.collate_fns = {}
        for key, collate_fn in zip(keys, collate_fns):
            func = collate_fn.pop('type')
            self.collate_fns[key] = partial(func, **collate_fn)

    def __call__(self, data_samples):
        # Drop empty samples, then dispatch the whole batch based on the
        # first sample's 'type'.
        data_samples = [
            data_sample for data_sample in data_samples
            if len(data_sample) > 0
        ]
        data_dict = {}
        key = data_samples[0]['type']
        data_dict[key] = self.collate_fns[key](data_samples)['data']

        return {'data': data_dict, 'data_samples': None}
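

# A minimal smoke test, assuming samples carry a 'type' key matching one of
# the configured keys (batches are homogeneous in 'type').
if __name__ == '__main__':
    collate = CollateConcat(
        collate_fns=[dict(type=collate_func_gen), dict(type=collate_func_und)],
        keys=['gen', 'und'])
    samples = [
        {'type': 'gen', 'input_ids': torch.tensor([1, 2, 3])},
        {'type': 'gen', 'input_ids': torch.tensor([1, 2])},
    ]
    out = collate(samples)
    print(out['data']['gen']['input_ids'].shape)  # torch.Size([2, 3])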