Spaces:
dreroc
/
Running on Zero

UniPic / src /runners /custom_runner.py
yichenchenchen's picture
Upload 25 files
ea88892 verified
import copy
import logging
import inspect
from torch.utils.data import DataLoader
from functools import partial
from typing import Callable, Dict, List, Optional, Union
from mmengine.logging import print_log
from mmengine.dist import get_rank
from mmengine.dataset import worker_init_fn as default_worker_init_fn
from mmengine.utils import digit_version
from mmengine.utils.dl_utils import TORCH_VERSION
from mmengine.runner import FlexibleRunner
from mmengine.registry import (
DATA_SAMPLERS,
DATASETS,
FUNCTIONS,
)
from xtuner.registry import BUILDER
def clean_concatdataset_fields(cfg):
    """Strip unsupported keys from every ``ConcatDataset`` config, in place.

    The config is walked recursively through dicts, lists and tuples.
    Whenever a dict describes a ``ConcatDataset`` -- its ``'type'`` entry is
    either the string ``"ConcatDataset"`` or an object (typically the class
    itself) whose ``__name__`` is ``"ConcatDataset"`` -- the keys that such a
    dataset cannot accept (currently only ``image_size``) are removed from
    that dict. All other dicts keep their keys.

    Args:
        cfg: A possibly nested config built from dicts, lists and tuples.
            Values of any other type are left untouched.

    Returns:
        The very same ``cfg`` object that was passed in, after cleaning.
    """
    # Keys that must not be forwarded to a ConcatDataset constructor.
    invalid_keys = ('image_size',)

    if isinstance(cfg, dict):
        node_type = cfg.get('type')
        # `type` may be a registry name or the class object itself; compare
        # by name so both spellings are recognised.
        if isinstance(node_type, str):
            type_name = node_type
        else:
            type_name = getattr(node_type, '__name__', None)

        if type_name == 'ConcatDataset':
            for key in invalid_keys:
                cfg.pop(key, None)

        # Recurse into the children; only nested containers are affected.
        for child in cfg.values():
            clean_concatdataset_fields(child)
    elif isinstance(cfg, (list, tuple)):
        for child in cfg:
            clean_concatdataset_fields(child)
    return cfg
class CustomRunner(FlexibleRunner):
    """``FlexibleRunner`` subclass with a customised dataloader builder.

    ``build_dataloader`` cleans ``ConcatDataset`` configs with
    ``clean_concatdataset_fields`` before building, and additionally accepts
    a ``collate_fn`` config whose ``type`` is a class, which is instantiated
    through ``BUILDER``.
    """

    def __init__(self, **kwargs):
        # Keyword-only pass-through to the parent runner.
        super().__init__(**kwargs)

    @staticmethod
    def build_dataloader(
        dataloader: Union[DataLoader, Dict],
        seed: Optional[int] = None,
        diff_rank_seed: bool = False,
    ) -> DataLoader:
        """Create a ``DataLoader`` from a config dict.

        The dataset, the sampler, the optional batch sampler, the worker
        init function and the collate function are resolved one after the
        other; every key still left in the config afterwards is forwarded
        verbatim to ``DataLoader``.

        Example config::

            dataloader = dict(
                dataset=dict(type='ToyDataset'),
                sampler=dict(type='DefaultSampler', shuffle=True),
                batch_size=1,
                num_workers=9
            )

        Args:
            dataloader (DataLoader or dict): Either a ready ``DataLoader``,
                which is returned unchanged, or a config dict describing
                one. The dict is deep-copied before use.
            seed (int, optional): Random seed. It is offered to the sampler
                as a default argument and, when no ``worker_init_fn`` is
                configured, bound into the default worker init function.
                Defaults to None.
            diff_rank_seed (bool): When True, the sampler is offered
                ``seed=None`` instead of ``seed``. Defaults to False.

        Returns:
            DataLoader: The loader built from the config.

        Raises:
            TypeError: If ``collate_fn`` is neither a dict nor a callable.
        """
        # An already constructed loader is handed back untouched.
        if isinstance(dataloader, DataLoader):
            return dataloader

        # Work on a private copy; the pops below consume it key by key.
        loader_kwargs = copy.deepcopy(dataloader)
        clean_concatdataset_fields(loader_kwargs)

        # ---- dataset ----
        # A non-dict value is passed through as-is so that DataLoader itself
        # reports the problem if it is not a valid dataset.
        dataset = loader_kwargs.pop('dataset')
        if isinstance(dataset, dict):
            dataset = DATASETS.build(dataset)
            if hasattr(dataset, 'full_init'):
                dataset.full_init()

        # ---- sampler ----
        # Same pass-through policy for non-dict values.
        sampler = loader_kwargs.pop('sampler')
        if isinstance(sampler, dict):
            sampler_defaults = {
                'dataset': dataset,
                'seed': None if diff_rank_seed else seed,
            }
            sampler = DATA_SAMPLERS.build(
                sampler, default_args=sampler_defaults)

        # ---- batch sampler (optional) ----
        # `batch_size` is consumed here only when the batch sampler is
        # described by a dict; otherwise it stays in the loader kwargs.
        batch_sampler = loader_kwargs.pop('batch_sampler', None)
        if isinstance(batch_sampler, dict):
            batch_defaults = {
                'dataset': dataset,
                'sampler': sampler,
                'batch_size': loader_kwargs.pop('batch_size'),
            }
            batch_sampler = DATA_SAMPLERS.build(
                batch_sampler, default_args=batch_defaults)

        # ---- worker init function ----
        init_fn: Optional[partial]
        if 'worker_init_fn' in loader_kwargs:
            # User supplied: look the function up by its registered type and
            # bind the remaining config entries as keyword arguments.
            init_spec = loader_kwargs.pop('worker_init_fn')
            init_type = init_spec.pop('type')
            registered_init = FUNCTIONS.get(init_type)
            assert callable(registered_init)
            init_fn = partial(registered_init, **init_spec)  # type: ignore
        elif seed is not None:
            # No custom function but a seed is given: use the default
            # worker init function with the seed bound in.
            silence_warning = loader_kwargs.pop(
                'disable_subprocess_warning', False)
            assert isinstance(silence_warning, bool), (
                'disable_subprocess_warning should be a bool, but got '
                f'{type(silence_warning)}')
            init_fn = partial(
                default_worker_init_fn,
                num_workers=loader_kwargs.get('num_workers'),
                rank=get_rank(),
                seed=seed,
                disable_subprocess_warning=silence_warning)
        else:
            init_fn = None

        # ---- persistent_workers ----
        # The option needs pytorch >= 1.7; on older versions it is dropped
        # with a warning.
        if 'persistent_workers' in loader_kwargs:
            if digit_version(TORCH_VERSION) < digit_version('1.7.0'):
                print_log(
                    '`persistent_workers` is only available when '
                    'pytorch version >= 1.7',
                    logger='current',
                    level=logging.WARNING)
                loader_kwargs.pop('persistent_workers')

        # ---- collate function ----
        # Without an explicit `collate_fn`, the function registered as
        # `pseudo_collate` is used instead of DataLoader's default collate.
        collate_spec = loader_kwargs.pop('collate_fn',
                                         {'type': 'pseudo_collate'})
        if isinstance(collate_spec, dict):
            collate_type = collate_spec.pop('type')
            if inspect.isclass(collate_type):
                # A class is instantiated through BUILDER with the whole
                # config (type restored); no partial binding is applied.
                collate_spec['type'] = collate_type
                collate_fn = BUILDER.build(collate_spec)
            else:
                # A string is resolved through FUNCTIONS; anything else is
                # used directly. The remaining entries become keyword args.
                if isinstance(collate_type, str):
                    collate_target = FUNCTIONS.get(collate_type)
                else:
                    collate_target = collate_type
                collate_fn = partial(
                    collate_target, **collate_spec)  # type: ignore
        elif callable(collate_spec):
            collate_fn = collate_spec
        else:
            raise TypeError(
                'collate_fn should be a dict or callable object, but got '
                f'{collate_spec}')

        # The plain sampler is only passed when no batch sampler is in use.
        return DataLoader(
            dataset=dataset,
            sampler=sampler if batch_sampler is None else None,
            batch_sampler=batch_sampler,
            collate_fn=collate_fn,
            worker_init_fn=init_fn,
            **loader_kwargs)