|
import copy |
|
import logging |
|
import inspect |
|
|
|
from torch.utils.data import DataLoader |
|
from functools import partial |
|
from typing import Callable, Dict, List, Optional, Union |
|
|
|
from mmengine.logging import print_log |
|
from mmengine.dist import get_rank |
|
from mmengine.dataset import worker_init_fn as default_worker_init_fn |
|
from mmengine.utils import digit_version |
|
from mmengine.utils.dl_utils import TORCH_VERSION |
|
from mmengine.runner import FlexibleRunner |
|
from mmengine.registry import ( |
|
DATA_SAMPLERS, |
|
DATASETS, |
|
FUNCTIONS, |
|
) |
|
from xtuner.registry import BUILDER |
|
|
|
|
|
def clean_concatdataset_fields(cfg, fields=('image_size',)):
    """Recursively strip unsupported fields from ``ConcatDataset`` configs.

    Walks an arbitrarily nested config structure (dicts and lists) and, for
    every dict whose ``type`` is ``"ConcatDataset"``, removes the keys listed
    in ``fields`` (e.g. ``image_size``) that ``ConcatDataset`` does not
    accept. The structure is modified in place.

    Args:
        cfg (dict | list | Any): Config node to clean. Non-container values
            are returned unchanged.
        fields (tuple[str, ...]): Keys to remove from each ``ConcatDataset``
            dict. Defaults to ``('image_size',)``, preserving the original
            behavior.

    Returns:
        The same ``cfg`` object, cleaned in place.
    """
    if isinstance(cfg, dict):
        # Strip the illegal keys only from ConcatDataset nodes.
        if cfg.get('type') == "ConcatDataset":
            for key in fields:
                cfg.pop(key, None)
        # Recurse into every value; nested dataset configs may themselves
        # contain ConcatDataset entries.
        for v in cfg.values():
            clean_concatdataset_fields(v, fields)
    elif isinstance(cfg, list):
        for item in cfg:
            clean_concatdataset_fields(item, fields)
    return cfg
|
|
|
|
|
|
|
class CustomRunner(FlexibleRunner):
    """FlexibleRunner subclass that sanitizes dataloader configs.

    The only customization is in :meth:`build_dataloader`, which runs
    :func:`clean_concatdataset_fields` on the config before building, so
    that ``ConcatDataset`` entries carrying unsupported fields (e.g.
    ``image_size``) do not break construction.
    """

    def __init__(
        self,
        **kwargs,
    ):
        # Pure pass-through constructor; kept as an explicit extension point.
        super().__init__(**kwargs)

    @staticmethod
    def build_dataloader(
        dataloader: Union[DataLoader, Dict],
        seed: Optional[int] = None,
        diff_rank_seed: bool = False,
    ) -> DataLoader:
        """Build dataloader.

        The method builds three components:

        - Dataset
        - Sampler
        - Dataloader

        An example of ``dataloader``::

            dataloader = dict(
                dataset=dict(type='ToyDataset'),
                sampler=dict(type='DefaultSampler', shuffle=True),
                batch_size=1,
                num_workers=9
            )

        Args:
            dataloader (DataLoader or dict): A Dataloader object or a dict to
                build Dataloader object. If ``dataloader`` is a Dataloader
                object, just returns itself.
            seed (int, optional): Random seed. Defaults to None.
            diff_rank_seed (bool): Whether or not set different seeds to
                different ranks. If True, the seed passed to sampler is set
                to None, in order to synchronize the seeds used in samplers
                across different ranks. Defaults to False.

        Returns:
            Dataloader: DataLoader build from ``dataloader_cfg``.
        """
        # Already-built loader: return it unchanged.
        if isinstance(dataloader, DataLoader):
            return dataloader

        # Deep-copy so the many pop() calls below never mutate the caller's
        # config object.
        dataloader_cfg = copy.deepcopy(dataloader)

        # Custom step: strip fields (currently ``image_size``) that
        # ConcatDataset configs must not carry; mutates the copy in place.
        clean_concatdataset_fields(dataloader_cfg)

        # --- Dataset ---
        dataset_cfg = dataloader_cfg.pop('dataset')
        if isinstance(dataset_cfg, dict):
            dataset = DATASETS.build(dataset_cfg)
            # Lazily-initialized mmengine datasets expose ``full_init``;
            # trigger it now so the sampler sees the final dataset length.
            if hasattr(dataset, 'full_init'):
                dataset.full_init()
        else:
            # Not a dict: assume an already-constructed dataset object and
            # let DataLoader complain later if it is actually invalid.
            dataset = dataset_cfg

        # --- Sampler ---
        sampler_cfg = dataloader_cfg.pop('sampler')
        if isinstance(sampler_cfg, dict):
            # With diff_rank_seed, pass None so each rank seeds its sampler
            # independently instead of sharing ``seed``.
            sampler_seed = None if diff_rank_seed else seed
            sampler = DATA_SAMPLERS.build(
                sampler_cfg,
                default_args=dict(dataset=dataset, seed=sampler_seed))
        else:
            # Assume an already-constructed sampler instance.
            sampler = sampler_cfg

        # --- Batch sampler (optional) ---
        batch_sampler_cfg = dataloader_cfg.pop('batch_sampler', None)
        if batch_sampler_cfg is None:
            batch_sampler = None
        elif isinstance(batch_sampler_cfg, dict):
            # ``batch_size`` is popped here (not ``get``) because DataLoader
            # forbids passing both ``batch_size`` and ``batch_sampler``.
            batch_sampler = DATA_SAMPLERS.build(
                batch_sampler_cfg,
                default_args=dict(
                    dataset=dataset,
                    sampler=sampler,
                    batch_size=dataloader_cfg.pop('batch_size')))
        else:
            # Assume an already-constructed batch sampler instance.
            batch_sampler = batch_sampler_cfg

        # --- worker_init_fn ---
        init_fn: Optional[partial]
        if 'worker_init_fn' in dataloader_cfg:
            # User-supplied init fn: resolve it from the FUNCTIONS registry
            # and bind the remaining cfg entries as keyword arguments.
            worker_init_fn_cfg = dataloader_cfg.pop('worker_init_fn')
            worker_init_fn_type = worker_init_fn_cfg.pop('type')
            worker_init_fn = FUNCTIONS.get(worker_init_fn_type)
            assert callable(worker_init_fn)
            init_fn = partial(worker_init_fn,
                              **worker_init_fn_cfg)
        else:
            if seed is not None:
                disable_subprocess_warning = dataloader_cfg.pop(
                    'disable_subprocess_warning', False)
                assert isinstance(disable_subprocess_warning, bool), (
                    'disable_subprocess_warning should be a bool, but got '
                    f'{type(disable_subprocess_warning)}')
                # Default init fn seeds each worker deterministically from
                # (num_workers, rank, seed).
                init_fn = partial(
                    default_worker_init_fn,
                    num_workers=dataloader_cfg.get('num_workers'),
                    rank=get_rank(),
                    seed=seed,
                    disable_subprocess_warning=disable_subprocess_warning)
            else:
                # No seed given: leave worker seeding to PyTorch defaults.
                init_fn = None

        # ``persistent_workers`` requires PyTorch >= 1.7; drop it with a
        # warning on older versions instead of crashing DataLoader.
        if ('persistent_workers' in dataloader_cfg
                and digit_version(TORCH_VERSION) < digit_version('1.7.0')):
            print_log(
                '`persistent_workers` is only available when '
                'pytorch version >= 1.7',
                logger='current',
                level=logging.WARNING)
            dataloader_cfg.pop('persistent_workers')

        # --- collate_fn ---
        # Defaults to mmengine's ``pseudo_collate`` when not configured.
        collate_fn_cfg = dataloader_cfg.pop('collate_fn',
                                            dict(type='pseudo_collate'))
        if isinstance(collate_fn_cfg, dict):
            collate_fn_type = collate_fn_cfg.pop('type')
            if isinstance(collate_fn_type, str):
                # Registered function name: look it up in FUNCTIONS.
                collate_fn = FUNCTIONS.get(collate_fn_type)
            elif inspect.isclass(collate_fn_type):
                # A class: restore 'type' and instantiate via BUILDER so the
                # remaining cfg entries become constructor arguments.
                collate_fn_cfg['type'] = collate_fn_type
                collate_fn = BUILDER.build(collate_fn_cfg)
            else:
                # Already a callable object.
                collate_fn = collate_fn_type
            if not inspect.isclass(collate_fn_type):
                # Bind leftover cfg entries as keyword args; class-built
                # instances already consumed their cfg in BUILDER.build.
                collate_fn = partial(collate_fn, **collate_fn_cfg)
        elif callable(collate_fn_cfg):
            # A bare callable is used as-is.
            collate_fn = collate_fn_cfg
        else:
            raise TypeError(
                'collate_fn should be a dict or callable object, but got '
                f'{collate_fn_cfg}')
        # ``sampler`` and ``batch_sampler`` are mutually exclusive in
        # DataLoader; any remaining cfg entries (batch_size, num_workers,
        # pin_memory, ...) are forwarded verbatim.
        data_loader = DataLoader(
            dataset=dataset,
            sampler=sampler if batch_sampler is None else None,
            batch_sampler=batch_sampler,
            collate_fn=collate_fn,
            worker_init_fn=init_fn,
            **dataloader_cfg)

        return data_loader
|
|