import copy import logging import inspect from torch.utils.data import DataLoader from functools import partial from typing import Callable, Dict, List, Optional, Union from mmengine.logging import print_log from mmengine.dist import get_rank from mmengine.dataset import worker_init_fn as default_worker_init_fn from mmengine.utils import digit_version from mmengine.utils.dl_utils import TORCH_VERSION from mmengine.runner import FlexibleRunner from mmengine.registry import ( DATA_SAMPLERS, DATASETS, FUNCTIONS, ) from xtuner.registry import BUILDER def clean_concatdataset_fields(cfg): """ 递归清除所有 ConcatDataset 配置中的非法字段(如 image_size) """ if isinstance(cfg, dict): # 如果是 ConcatDataset 层,清除非法字段 if cfg.get('type') == "ConcatDataset": for key in ['image_size']: if key in cfg: del cfg[key] # 递归处理子字段 for k, v in cfg.items(): clean_concatdataset_fields(v) elif isinstance(cfg, list): for item in cfg: clean_concatdataset_fields(item) return cfg class CustomRunner(FlexibleRunner): def __init__( self, **kwargs, ): super().__init__(**kwargs) @staticmethod def build_dataloader( dataloader: Union[DataLoader, Dict], seed: Optional[int] = None, diff_rank_seed: bool = False, ) -> DataLoader: """Build dataloader. The method builds three components: - Dataset - Sampler - Dataloader An example of ``dataloader``:: dataloader = dict( dataset=dict(type='ToyDataset'), sampler=dict(type='DefaultSampler', shuffle=True), batch_size=1, num_workers=9 ) Args: dataloader (DataLoader or dict): A Dataloader object or a dict to build Dataloader object. If ``dataloader`` is a Dataloader object, just returns itself. seed (int, optional): Random seed. Defaults to None. diff_rank_seed (bool): Whether or not set different seeds to different ranks. If True, the seed passed to sampler is set to None, in order to synchronize the seeds used in samplers across different ranks. Defaults to False. Returns: Dataloader: DataLoader build from ``dataloader_cfg``. """ if isinstance(dataloader, DataLoader): return dataloader dataloader_cfg = copy.deepcopy(dataloader) clean_concatdataset_fields(dataloader_cfg) # build dataset dataset_cfg = dataloader_cfg.pop('dataset') if isinstance(dataset_cfg, dict): dataset = DATASETS.build(dataset_cfg) if hasattr(dataset, 'full_init'): dataset.full_init() else: # fallback to raise error in dataloader # if `dataset_cfg` is not a valid type dataset = dataset_cfg # build sampler sampler_cfg = dataloader_cfg.pop('sampler') if isinstance(sampler_cfg, dict): sampler_seed = None if diff_rank_seed else seed sampler = DATA_SAMPLERS.build( sampler_cfg, default_args=dict(dataset=dataset, seed=sampler_seed)) else: # fallback to raise error in dataloader # if `sampler_cfg` is not a valid type sampler = sampler_cfg # build batch sampler batch_sampler_cfg = dataloader_cfg.pop('batch_sampler', None) if batch_sampler_cfg is None: batch_sampler = None elif isinstance(batch_sampler_cfg, dict): batch_sampler = DATA_SAMPLERS.build( batch_sampler_cfg, default_args=dict( dataset=dataset, sampler=sampler, batch_size=dataloader_cfg.pop('batch_size'))) else: # fallback to raise error in dataloader # if `batch_sampler_cfg` is not a valid type batch_sampler = batch_sampler_cfg # build dataloader init_fn: Optional[partial] if 'worker_init_fn' in dataloader_cfg: worker_init_fn_cfg = dataloader_cfg.pop('worker_init_fn') worker_init_fn_type = worker_init_fn_cfg.pop('type') worker_init_fn = FUNCTIONS.get(worker_init_fn_type) assert callable(worker_init_fn) init_fn = partial(worker_init_fn, **worker_init_fn_cfg) # type: ignore else: if seed is not None: disable_subprocess_warning = dataloader_cfg.pop( 'disable_subprocess_warning', False) assert isinstance(disable_subprocess_warning, bool), ( 'disable_subprocess_warning should be a bool, but got ' f'{type(disable_subprocess_warning)}') init_fn = partial( default_worker_init_fn, num_workers=dataloader_cfg.get('num_workers'), rank=get_rank(), seed=seed, disable_subprocess_warning=disable_subprocess_warning) else: init_fn = None # `persistent_workers` requires pytorch version >= 1.7 if ('persistent_workers' in dataloader_cfg and digit_version(TORCH_VERSION) < digit_version('1.7.0')): print_log( '`persistent_workers` is only available when ' 'pytorch version >= 1.7', logger='current', level=logging.WARNING) dataloader_cfg.pop('persistent_workers') # The default behavior of `collat_fn` in dataloader is to # merge a list of samples to form a mini-batch of Tensor(s). # However, in mmengine, if `collate_fn` is not defined in # dataloader_cfg, `pseudo_collate` will only convert the list of # samples into a dict without stacking the batch tensor. collate_fn_cfg = dataloader_cfg.pop('collate_fn', dict(type='pseudo_collate')) if isinstance(collate_fn_cfg, dict): collate_fn_type = collate_fn_cfg.pop('type') if isinstance(collate_fn_type, str): collate_fn = FUNCTIONS.get(collate_fn_type) elif inspect.isclass(collate_fn_type): collate_fn_cfg['type'] = collate_fn_type collate_fn = BUILDER.build(collate_fn_cfg) else: collate_fn = collate_fn_type if not inspect.isclass(collate_fn_type): collate_fn = partial(collate_fn, **collate_fn_cfg) # type: ignore elif callable(collate_fn_cfg): collate_fn = collate_fn_cfg else: raise TypeError( 'collate_fn should be a dict or callable object, but got ' f'{collate_fn_cfg}') data_loader = DataLoader( dataset=dataset, sampler=sampler if batch_sampler is None else None, batch_sampler=batch_sampler, collate_fn=collate_fn, worker_init_fn=init_fn, **dataloader_cfg) return data_loader