import glob
import io
import os
import pickle
import shutil
import tempfile

import numpy as np
import torch

from src.logger.logger import _logger


def to_filelist(args, mode="train"):
    """Expand the glob patterns configured for ``mode`` into concrete file lists.

    Each entry of the selected pattern list is either a plain glob pattern or a
    keyword-prefixed one (``"name:/path/to/files*"``). Patterns without a
    keyword are grouped under the name ``"_"``.

    Args:
        args: Namespace-like object. Reads ``data_train`` / ``data_val`` /
            ``data_test`` (depending on ``mode``), ``local_rank`` and
            ``copy_inputs``.
        mode: One of ``"train"``, ``"val"`` or ``"test"``.

    Returns:
        A tuple ``(file_dict, filelist)`` where ``file_dict`` maps each group
        name to its list of files and ``filelist`` is the concatenation of all
        groups.

    Raises:
        NotImplementedError: If ``mode`` is not one of the supported values.
        KeyError: If ``args.local_rank`` is set, ``mode`` is ``"train"`` and
            the ``LOCAL_WORLD_SIZE`` environment variable is missing.
        AssertionError: If a rank ends up with no files for some group, or if
            the same file appears more than once in the final list.
    """
    if mode == "train":
        flist = args.data_train
    elif mode == "val":
        flist = args.data_val
    elif mode == "test":
        flist = args.data_test
    else:
        raise NotImplementedError("Invalid mode %s" % mode)
    print(flist)

    # keyword-based: 'a:/path/to/a b:/path/to/b'
    file_dict = {}
    for f in flist:
        if ":" in f:
            # Split on the first colon only, so a path that itself contains a
            # colon does not make the unpacking fail.
            name, fp = f.split(":", 1)
        else:
            name, fp = "_", f
        file_dict.setdefault(name, []).extend(glob.glob(fp))

    # Sort so that every process sees the files in the same order.
    file_dict = {name: sorted(files) for name, files in file_dict.items()}

    if args.local_rank is not None and mode == "train":
        # Give each local rank a disjoint, strided subset of every group.
        local_world_size = int(os.environ["LOCAL_WORLD_SIZE"])
        new_file_dict = {}
        for name, files in file_dict.items():
            new_files = files[args.local_rank :: local_world_size]
            assert len(new_files) > 0, (
                "No files left for local rank %s in file group %s"
                % (args.local_rank, name)
            )
            np.random.shuffle(new_files)
            new_file_dict[name] = new_files
        file_dict = new_file_dict

    if args.copy_inputs:
        # Keep the directory exactly as mkdtemp() created it (unique and only
        # accessible by the current user). Deleting and re-creating it would
        # drop those guarantees.
        tmpdir = tempfile.mkdtemp()
        new_file_dict = {name: [] for name in file_dict}
        for name, files in file_dict.items():
            for src in files:
                # Mirror the source path underneath the temporary directory.
                dest = os.path.join(tmpdir, src.lstrip("/"))
                os.makedirs(os.path.dirname(dest), exist_ok=True)
                shutil.copy2(src, dest)
                _logger.info("Copied file %s to %s", src, dest)
                new_file_dict[name].append(dest)
            if len(files) != len(new_file_dict[name]):
                _logger.error(
                    "Only %d/%d files copied for %s file group %s",
                    len(new_file_dict[name]),
                    len(files),
                    mode,
                    name,
                )
        file_dict = new_file_dict

    filelist = [path for files in file_dict.values() for path in files]
    assert len(filelist) == len(set(filelist)), (
        "Duplicate files found in the %s file list" % mode
    )
    return file_dict, filelist


def clear_empty_paths(dir):
    """Remove the immediate subdirectories of ``dir`` that are empty.

    A subdirectory counts as empty when it contains neither files nor folders.
    Only the direct children of ``dir`` are inspected; plain files are skipped.

    Args:
        dir: Path of the folder whose empty subdirectories are removed.
    """
    for entry in os.listdir(dir):
        path = os.path.join(dir, entry)
        if not os.path.isdir(path):
            continue
        if not os.listdir(path):
            shutil.rmtree(path)
            _logger.info("Removed empty path %s", entry)


class CPU_Unpickler(pickle.Unpickler):
    """Unpickler that loads pickled torch storages onto the CPU.

    NOTE: unpickling executes arbitrary code; only use this on trusted data.
    """

    def find_class(self, module, name):
        """Resolve ``module.name``, redirecting torch storage loading to the CPU."""
        if module == "torch.storage" and name == "_load_from_bytes":
            # Replace torch's loader with one that maps the storage to the CPU.
            return lambda b: torch.load(io.BytesIO(b), map_location="cpu")
        return super().find_class(module, name)