---
license: mit
library_name: transformers
pipeline_tag: image-classification
---

# SPG: Sequential Policy Gradient for Adaptive Hyperparameter Optimization

This repository contains the models described in the paper Sequential Policy Gradient for Adaptive Hyperparameter Optimization.

Project page · GitHub repository

🚀 If you're using Jupyter or Colab, you can follow the demo and run it on a single GPU: [Open In Colab](https://colab.research.google.com/#fileId=https%3A//huggingface.co/UniversalAlgorithmic/SPG/blob/main/demo_nas.ipynb)

## Model Zoo: Adaptive Hyperparameter Optimization (HPO) via the SPG Algorithm

**Table 1: Performance of pre-trained vs. SPG-retrained models on ImageNet-1K**

| Model | SPG | # Params | Acc@1 (%) | Acc@5 (%) | Weights | Command to reproduce |
|-------|-----|----------|-----------|-----------|---------|----------------------|
| MobileNet-V2 | ❌ | 3.5 M | 71.878 | 90.286 | | Recipe |
| MobileNet-V2 | ✅ | 3.5 M | 72.104 | 90.316 | | `examples/image-classification/run.sh` |
| ResNet-50 | ❌ | 25.6 M | 76.130 | 92.862 | | Recipe |
| ResNet-50 | ✅ | 25.6 M | 77.234 | 93.322 | | `examples/image-classification/run.sh` |
| EfficientNet-V2-M | ❌ | 54.1 M | 85.112 | 97.156 | | Recipe |
| EfficientNet-V2-M | ✅ | 54.1 M | 85.218 | 97.208 | | `examples/image-classification/run.sh` |
| ViT-B16 | ❌ | 86.6 M | 81.072 | 95.318 | | Recipe |
| ViT-B16 | ✅ | 86.6 M | 81.092 | 95.304 | | `examples/image-classification/run.sh` |
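The SPG-retrained checkpoints are hosted in this repository. Below is a minimal loading sketch, assuming a torchvision-style state dict; the filename `resnet50_spg.pth` is hypothetical, so check the repository's file listing for the actual name:

```python
# Hedged sketch: pull an SPG-retrained checkpoint from this repo and load it
# into the matching torchvision architecture. The filename is a placeholder.
import torch
import torchvision
from huggingface_hub import hf_hub_download

ckpt_path = hf_hub_download(repo_id="UniversalAlgorithmic/SPG",
                            filename="resnet50_spg.pth")  # hypothetical name
model = torchvision.models.resnet50()  # same 25.6 M-parameter architecture as in Table 1
state = torch.load(ckpt_path, map_location="cpu")
model.load_state_dict(state.get("model", state))  # reference scripts nest weights under "model"
model.eval()
```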

**Table 2: Performance of pre-trained vs. SPG-retrained models.** All models are evaluated on a subset of COCO val2017, over the 21/20 categories present in the Pascal VOC dataset; each cell reports the 21-category result before the slash and the 20-category result after it.

⚠️ All models reported by TorchVision (with the `COCO_WITH_VOC_LABELS_V1` weights) were benchmarked using only 20 categories. Researchers should first download the pre-trained model from TorchVision and re-evaluate it under the 21-category framework (including "background"); a sketch of this re-evaluation follows the table.

| Model | SPG | # Params | mIoU (%) | Pixelwise Acc (%) | Weights | Command to reproduce |
|-------|-----|----------|----------|-------------------|---------|----------------------|
| FCN-ResNet50 | ❌ | 35.3 M | 58.9/60.5 | 90.9/91.4 | | Recipe |
| FCN-ResNet50 | ✅ | 35.3 M | 59.4/60.9 | 90.9/91.6 | | `examples/semantic-segmentation/run.sh` |
| FCN-ResNet101 | ❌ | 54.3 M | 62.2/63.7 | 91.1/91.9 | | Recipe |
| FCN-ResNet101 | ✅ | 54.3 M | 62.4/64.3 | 91.1/91.9 | | `examples/semantic-segmentation/run.sh` |
| DeepLabV3-ResNet50 | ❌ | 42.0 M | 63.8/66.4 | 91.5/92.4 | | Recipe |
| DeepLabV3-ResNet50 | ✅ | 42.0 M | 64.2/66.6 | 91.6/92.5 | | `examples/semantic-segmentation/run.sh` |
| DeepLabV3-ResNet101 | ❌ | 61.0 M | 65.3/67.4 | 91.7/92.4 | | Recipe |
| DeepLabV3-ResNet101 | ✅ | 61.0 M | 65.7/67.8 | 91.8/92.5 | | `examples/semantic-segmentation/run.sh` |
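To make the warning above actionable, here is a hedged sketch of the 21-category re-evaluation: it downloads TorchVision's `COCO_WITH_VOC_LABELS_V1` weights and scores predictions through a 21×21 confusion matrix (20 VOC classes plus "background"). The dataloader is omitted; any loader yielding images and ground-truth masks with labels in `[0, 20]` (255 treated as void) will do:

```python
import torch
from torchvision.models.segmentation import fcn_resnet50, FCN_ResNet50_Weights

NUM_CLASSES = 21  # 20 Pascal VOC categories + "background"
model = fcn_resnet50(weights=FCN_ResNet50_Weights.COCO_WITH_VOC_LABELS_V1).eval()
confmat = torch.zeros(NUM_CLASSES, NUM_CLASSES, dtype=torch.int64)

@torch.no_grad()
def update(images, targets):
    """Accumulate one batch of (images, ground-truth masks) into the confusion matrix."""
    preds = model(images)["out"].argmax(1)            # (N, H, W) predicted class indices
    valid = (targets >= 0) & (targets < NUM_CLASSES)  # skip the 255 "void" label
    idx = NUM_CLASSES * targets[valid].long() + preds[valid]
    confmat.add_(torch.bincount(idx, minlength=NUM_CLASSES ** 2)
                      .reshape(NUM_CLASSES, NUM_CLASSES))

def scores():
    """Return (mIoU, pixelwise accuracy) over all accumulated batches."""
    tp = confmat.diag().float()
    iou = tp / (confmat.sum(0) + confmat.sum(1) - confmat.diag()).float()
    return iou.mean().item(), (tp.sum() / confmat.sum()).item()
```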

**Table 3: Performance comparison of fine-tuned vs. SPG-retrained models across NLP and speech benchmarks.**

- GLUE (text classification: BERT on the CoLA, SST-2, MRPC, QQP, QNLI, and RTE tasks)
- SQuAD (question answering: BERT)
- SUPERB (speech classification: Wav2Vec2 for Audio Classification (AC))
| Task | SPG | Metric Type | Performance (%) | Weights | Command to reproduce |
|------|-----|-------------|-----------------|---------|----------------------|
| CoLA | ❌ | Matthews corr. | 56.53 | | Recipe |
| CoLA | ✅ | Matthews corr. | 62.13 | | `examples/text-classification/run.sh` |
| SST-2 | ❌ | Accuracy | 92.32 | | Recipe |
| SST-2 | ✅ | Accuracy | 92.54 | | `examples/text-classification/run.sh` |
| MRPC | ❌ | F1/Accuracy | 88.85/84.09 | | Recipe |
| MRPC | ✅ | F1/Accuracy | 91.10/87.25 | | `examples/text-classification/run.sh` |
| QQP | ❌ | F1/Accuracy | 87.49/90.71 | | Recipe |
| QQP | ✅ | F1/Accuracy | 89.72/90.88 | | `examples/text-classification/run.sh` |
| QNLI | ❌ | Accuracy | 90.66 | | Recipe |
| QNLI | ✅ | Accuracy | 91.10 | | `examples/text-classification/run.sh` |
| RTE | ❌ | Accuracy | 65.70 | | Recipe |
| RTE | ✅ | Accuracy | 72.56 | | `examples/text-classification/run.sh` |
| Q/A* | ❌ | F1/Exact match | 88.52/81.22 | | Recipe |
| Q/A* | ✅ | F1/Exact match | 88.67/81.51 | | `examples/question-answering/run.sh` |
| AC† | ❌ | Accuracy | 98.26 | | Recipe |
| AC† | ✅ | Accuracy | 98.31 | | `examples/audio-classification/run.sh` |
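For reference, the CoLA rows above report Matthews correlation, which stays informative on CoLA's skewed label distribution where plain accuracy saturates. A minimal sketch with scikit-learn (toy labels, not real model outputs):

```python
from sklearn.metrics import matthews_corrcoef

y_true = [1, 1, 0, 1, 0, 0, 1, 1]  # toy gold labels, not real CoLA data
y_pred = [1, 0, 0, 1, 0, 1, 1, 1]  # toy predictions
print(f"Matthews corr: {100 * matthews_corrcoef(y_true, y_pred):.2f}")
```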

## Model Zoo: Neural Architecture Search (NAS) via the SPG Algorithm

**Table 4: Performance of pre-trained vs. SPG-retrained models on ImageNet-1K.** Depending on the base model, we explore the following architectures:

- ResNet-18: ResNet-18, ResNet-27, ResNet-36, ResNet-45
- ResNet-34: ResNet-34, ResNet-40, ResNet-46, ResNet-52
- ResNet-50: ResNet-50, ResNet-53, ResNet-56, ResNet-59

⚠️ Our SPG differs from most NAS algorithms, which typically use a gating network for architecture selection. We employ neither a gating network nor a proxy network. Instead, after policy optimization we keep only the base architecture (ResNet-18, ResNet-34, or ResNet-50) and remove all the others (ResNet-27/36/45, ResNet-40/46/52, and ResNet-53/56/59); a sketch of this extraction step follows the table below.

| Model | SPG | # Params | Acc@1 (%) | Acc@5 (%) | Weights | Command to reproduce |
|-------|-----|----------|-----------|-----------|---------|----------------------|
| ResNet-18 | ❌ | 11.7 M | 69.758 | 89.078 | | Recipe |
| ResNet-18 | ✅ | 11.7 M | 70.092 | 89.314 | | `examples/neural-architecture-search/run.sh` |
| ResNet-34 | ❌ | 21.8 M | 73.314 | 91.420 | | Recipe |
| ResNet-34 | ✅ | 21.8 M | 73.900 | 93.536 | | `examples/neural-architecture-search/run.sh` |
| ResNet-50 | ❌ | 25.6 M | 76.130 | 92.862 | | Recipe |
| ResNet-50 | ✅ | 25.6 M | 77.234 | 93.322 | | `examples/neural-architecture-search/run.sh` |
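To make the extraction step concrete, here is an illustrative sketch; it is not the paper's actual code, and every name in it is hypothetical:

```python
import torch.nn as nn
import torchvision

class SPGSearchSpace(nn.Module):
    """Hypothetical container: one base network plus deeper auxiliary variants."""
    def __init__(self):
        super().__init__()
        self.base = torchvision.models.resnet18()  # the branch that survives
        # Deeper variants (e.g. ResNet-27/36/45) would be registered here; they
        # exist only while the policy is being optimized.
        self.aux = nn.ModuleList()

def export_base(space: SPGSearchSpace) -> nn.Module:
    # After policy optimization, every auxiliary branch is discarded: the
    # deployed model is exactly the base architecture with retrained weights.
    return space.base

model = export_base(SPGSearchSpace())  # plain ResNet-18; no gating or proxy network
```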

## Requirements

1. Install `torch>=2.0.0+cu118`.

2. Install the remaining pip packages:

   ```bash
   cd examples
   pip install -r requirements.txt
   ```
3. Prepare the ImageNet dataset manually and place it in `/path/to/imagenet`. For the image-classification examples, pass the argument `--data-path=/path/to/imagenet` to the training script. The extracted dataset directory should follow this structure:

   ```
   /path/to/imagenet/
       train/
           n01440764/
               n01440764_18.JPEG ...
           n01443537/
               n01443537_2.JPEG ...
       val/
           n01440764/
               ILSVRC2012_val_00000293.JPEG ...
           n01443537/
               ILSVRC2012_val_00000236.JPEG ...
   ```
4. Prepare the MS-COCO 2017 dataset manually and place it in `/path/to/coco`. For the semantic-segmentation examples, pass the argument `--data-path=/path/to/coco` to the training script. The extracted dataset directory should follow this structure:

   ```
   /path/to/coco/
       annotations/
           many_json_files.json ...
       train2017/
           000000000009.jpg ...
       val2017/
           000000000139.jpg ...
   ```
5. For the 🗣️ Keyword Spotting subset, Common Language, SQuAD, Common Voice, GLUE, and WMT datasets, manual downloading is not required; they are loaded automatically via the Hugging Face Datasets library when running our audio-classification, question-answering, speech-recognition, text-classification, or translation examples. (A combined sanity check for items 3-5 follows this list.)
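Before launching any training, the hedged sketch below sanity-checks items 3-5. The `/path/to/...` placeholders must point at your local copies; the GLUE call triggers an automatic download on first use:

```python
import torchvision
from datasets import load_dataset

# Item 3: the ImageNet layout above is exactly what ImageFolder expects.
train_set = torchvision.datasets.ImageFolder("/path/to/imagenet/train")
print(len(train_set.classes), "classes")  # expect 1000

# Item 4: COCO images and annotations load via CocoDetection (needs pycocotools).
coco_val = torchvision.datasets.CocoDetection(
    root="/path/to/coco/val2017",
    annFile="/path/to/coco/annotations/instances_val2017.json",
)
print(len(coco_val), "val images")  # expect 5000

# Item 5: Hugging Face datasets fetch themselves on first use, e.g. GLUE CoLA.
cola = load_dataset("glue", "cola")
print(cola["train"][0])
```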

## Training

### Retrain model on ImageNet-1K

We use training recipes similar to those in PyTorch Vision's classification reference to retrain MobileNet-V2, ResNet, EfficientNet-V2, and ViT with our SPG on ImageNet-1K. The following commands can be used:

```bash
cd ./examples/image-classification

# MobileNet-V2
torchrun --nproc_per_node=4 train.py \
  --data-path /path/to/imagenet/ \
  --model mobilenet_v2 --output-dir mobilenet_v2 --weights MobileNet_V2_Weights.IMAGENET1K_V1 \
  --batch-size 192 --epochs 40 --lr 0.0004 --lr-step-size 10 --lr-gamma 0.5 --wd 0.00004 \
  --apply-trp --trp-depths 1 --trp-p 0.15 --trp-lambdas 0.4 0.2 0.1

# ResNet-50
torchrun --nproc_per_node=4 train.py \
  --data-path /path/to/imagenet/ \
  --model resnet50 --output-dir resnet50 --weights ResNet50_Weights.IMAGENET1K_V1 \
  --batch-size 64 --epochs 40 --lr 0.0004 --lr-step-size 10 --lr-gamma 0.5 --print-freq 100 \
  --apply-trp --trp-depths 1 --trp-p 0.2 --trp-lambdas 0.4 0.2 0.1

# EfficientNet-V2-M
torchrun --nproc_per_node=4 train.py \
  --data-path /path/to/imagenet/ \
  --model efficientnet_v2_m --output-dir efficientnet_v2_m --weights EfficientNet_V2_M_Weights.IMAGENET1K_V1 \
  --epochs 10 --batch-size 64 --lr 5e-9 --lr-scheduler cosineannealinglr --weight-decay 0.00002 \
  --lr-warmup-method constant --lr-warmup-epochs 8 --lr-warmup-decay 0. \
  --auto-augment ta_wide --random-erase 0.1 --label-smoothing 0.1 --mixup-alpha 0.2 --cutmix-alpha 1.0 --norm-weight-decay 0.0 \
  --train-crop-size 384 --val-crop-size 480 --val-resize-size 480 --ra-sampler --ra-reps 4 --print-freq 100 \
  --apply-trp --trp-depths 1 --trp-p 0.2 --trp-lambdas 0.4 0.2 0.1

# ViT-B-16
torchrun --nproc_per_node=4 train.py \
  --data-path /path/to/imagenet/ \
  --model vit_b_16 --output-dir vit_b_16 --weights ViT_B_16_Weights.IMAGENET1K_V1 \
  --epochs 5 --batch-size 196 --opt adamw --lr 5e-9 --lr-scheduler cosineannealinglr --wd 0.3 \
  --lr-warmup-method constant --lr-warmup-epochs 3 --lr-warmup-decay 0. \
  --amp --label-smoothing 0.11 --mixup-alpha 0.2 --auto-augment ra --clip-grad-norm 1 --cutmix-alpha 1.0 \
  --apply-trp --trp-depths 1 --trp-p 0.1 --trp-lambdas 0.4 0.2 0.1 --print-freq 100
```
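To check a retrained model against the Acc@1 numbers in Table 1, the sketch below scores a checkpoint on the ImageNet validation split. It assumes the script writes torchvision-reference-style checkpoints (weights under a `"model"` key in `<output-dir>/checkpoint.pth`); adjust the key and path if this repo's script differs:

```python
import torch
import torchvision
from torchvision.models import ResNet50_Weights

weights = ResNet50_Weights.IMAGENET1K_V1           # reuse the matching eval transforms
model = torchvision.models.resnet50()
ckpt = torch.load("resnet50/checkpoint.pth", map_location="cpu")
model.load_state_dict(ckpt.get("model", ckpt))     # assumed checkpoint layout
model.eval()

val_set = torchvision.datasets.ImageFolder("/path/to/imagenet/val",
                                           transform=weights.transforms())
loader = torch.utils.data.DataLoader(val_set, batch_size=64, num_workers=4)

correct = total = 0
with torch.no_grad():
    for images, targets in loader:
        correct += (model(images).argmax(1) == targets).sum().item()
        total += targets.numel()
print(f"Acc@1: {100 * correct / total:.3f}%")
```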

### Retrain model on MS-COCO 2017

We use training recipes similar to those in PyTorch Vision's segmentation reference to retrain FCN and DeepLabV3 with our SPG on the COCO dataset. The following command can be used:


```bash
cd ./examples/semantic-segmentation

# FCN-ResNet50
torchrun --nproc_per_node=4 train.py \
  --workers 4 --dataset coco --data-path /path/to/coco/ \
  --model fcn_resnet50 --aux-loss --output-dir fcn_resnet50 --weights FCN_ResNet50
```