import argparse
import glob
import os
import random
import sys

# Suffix carried by Cityscapes image files; it is stripped from split entries.
_IMAGE_SUFFIX = "_leftImg8bit.png"


def _collect_pngs(split_dir):
    """Return every ``*.png`` found directly inside each sub-folder of *split_dir*."""
    found = []
    for city_folder in os.listdir(split_dir):
        city_path = os.path.join(split_dir, city_folder)
        if os.path.isdir(city_path):
            found.extend(glob.glob(os.path.join(city_path, "*.png")))
    return found


def _to_split_entry(image_path, leftimg8bit_path):
    """Return *image_path* relative to *leftimg8bit_path*, minus the image suffix."""
    rel_path = os.path.relpath(image_path, leftimg8bit_path)
    if rel_path.endswith(_IMAGE_SUFFIX):
        rel_path = rel_path[: -len(_IMAGE_SUFFIX)]
    return rel_path


def _write_split(file_path, entries):
    """Write *entries* to *file_path*, one per line, each ended by a single newline."""
    with open(file_path, "w", encoding="utf-8") as f:
        f.writelines(f"{item}\n" for item in entries)


def prepare_folds(cityscapes_path, output_dir, n_splits=3):
    """Prepare k-fold cross-validation splits for the Cityscapes dataset.

    The images under ``leftImg8bit/train`` and ``leftImg8bit/val`` are pooled,
    sorted, shuffled with a fixed seed and cut into ``n_splits`` folds. For
    fold ``i`` (1-based) two files are written to ``output_dir``:
    ``fold_{i}_train_split.txt`` and ``fold_{i}_val_split.txt``. Each line is
    an image path relative to ``leftImg8bit`` with the ``_leftImg8bit.png``
    suffix removed.

    Problems are reported on stdout rather than raised. A missing ``val``
    directory is reported but is not fatal.

    Args:
        cityscapes_path (str): Path to the root Cityscapes directory
            (containing leftImg8bit and gtFine).
        output_dir (str): Directory to save the split files. Created if it
            does not exist.
        n_splits (int): Number of folds for cross-validation. Must be at
            least 2 and no larger than the number of images found.

    Returns:
        bool: True if the split files were written, False if an error was
        reported and nothing was written.
    """
    leftimg8bit_path = os.path.join(cityscapes_path, "leftImg8bit")
    train_img_dir = os.path.join(leftimg8bit_path, "train")
    val_img_dir = os.path.join(leftimg8bit_path, "val")

    # isdir (not exists) so a stray regular file cannot reach os.listdir.
    if not os.path.isdir(train_img_dir):
        print(f"Error: Training image directory not found: {train_img_dir}")
        print(
            f"Please ensure '{cityscapes_path}' is the correct root and contains 'leftImg8bit/train'."
        )
        return False
    train_files = _collect_pngs(train_img_dir)

    val_files = []
    if os.path.isdir(val_img_dir):
        val_files = _collect_pngs(val_img_dir)
    else:
        # Reported, but the run continues with the training images only.
        print(f"Error: Validation image directory not found: {val_img_dir}")
        print(
            f"Please ensure '{cityscapes_path}' is the correct root and contains 'leftImg8bit/val'."
        )

    if not train_files and not val_files:
        print(f"Error: No image files found in {train_img_dir} or {val_img_dir}.")
        print("Please check your Cityscapes dataset structure and path.")
        return False

    # Sort first so the seeded shuffle below does not depend on the order in
    # which the filesystem happened to list the files.
    all_files_relative = sorted(
        _to_split_entry(f, leftimg8bit_path) for f in train_files + val_files
    )

    # Checked here so a bad value yields a readable message instead of a
    # ValueError traceback from KFold, and before anything is created on disk.
    if n_splits < 2 or n_splits > len(all_files_relative):
        print(
            f"Error: n_splits={n_splits} is invalid; it must be at least 2 and at most "
            f"the number of images found ({len(all_files_relative)})."
        )
        return False

    # A private generator gives the same order as random.seed(42) followed by
    # random.shuffle, without reseeding the process-wide generator.
    random.Random(42).shuffle(all_files_relative)

    # Imported here so `--help` and the validation above do not load scikit-learn.
    from sklearn.model_selection import KFold

    kf = KFold(n_splits=n_splits, shuffle=False)  # Shuffle is already done

    os.makedirs(output_dir, exist_ok=True)

    for i, (train_index, val_index) in enumerate(kf.split(all_files_relative)):
        fold_train_files = [all_files_relative[k] for k in train_index]
        fold_val_files = [all_files_relative[k] for k in val_index]

        _write_split(
            os.path.join(output_dir, f"fold_{i + 1}_train_split.txt"), fold_train_files
        )
        _write_split(
            os.path.join(output_dir, f"fold_{i + 1}_val_split.txt"), fold_val_files
        )

        print(f"Fold {i + 1}: {len(fold_train_files)} train, {len(fold_val_files)} val")

        # Debug output: show a few entries from each side of the fold.
        print("Sample train files:")
        for sample in fold_train_files[:3]:
            print(f" {sample}")
        print("Sample val files:")
        for sample in fold_val_files[:3]:
            print(f" {sample}")

    print(f"Split files saved to: {os.path.abspath(output_dir)}")
    return True


def main():
    """Parse the command line, build the splits and print usage hints.

    Exits with status 1 if the dataset path is not a directory or if
    ``prepare_folds`` reports a failure.
    """
    parser = argparse.ArgumentParser(description="Prepare Cityscapes k-fold splits.")
    parser.add_argument(
        "cityscapes_path",
        type=str,
        help="Absolute path to the Cityscapes dataset directory.",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default=None,
        help="Directory to save the split files. If not provided, a 'splits' folder will be created inside the cityscapes_path.",
    )
    parser.add_argument(
        "--n_splits", type=int, default=3, help="Number of folds (default: 3)."
    )
    args = parser.parse_args()

    if args.output_dir is None:
        effective_output_dir = os.path.join(args.cityscapes_path, "splits")
    else:
        effective_output_dir = args.output_dir

    abs_cityscapes_path = os.path.abspath(args.cityscapes_path)
    if not os.path.isdir(abs_cityscapes_path):
        print(
            f"Error: Cityscapes path not found or is not a directory: {abs_cityscapes_path}"
        )
        # sys.exit rather than the site-provided exit(), which may be absent.
        sys.exit(1)

    succeeded = prepare_folds(abs_cityscapes_path, effective_output_dir, args.n_splits)

    script_name = os.path.basename(__file__)
    print("\nTo run this script again, for example:")
    print(f"python {script_name} /path/to/your/cityscapes")
    if args.output_dir is not None:
        print(
            f"python {script_name} /path/to/your/cityscapes --output_dir {args.output_dir}"
        )
    print(
        "Replace '/path/to/your/cityscapes' with the actual path to your Cityscapes dataset."
    )

    # Let shell pipelines see that no split files were produced.
    if not succeeded:
        sys.exit(1)


if __name__ == "__main__":
    main()