# SPG (Sequential Policy Gradient): Lightweight Reinforcement Learning for Model Performance

🚀 If you're using Jupyter or Colab, you can follow the demo and run it on a single GPU: Open In Colab

## Model Zoo: Adaptive Hyperparameter Optimization (HPO) via the SPG Algorithm

Table 1: Performance of pre-trained vs. SPG-retrained models on ImageNet-1K

| Model | SPG | # Params | Acc@1 (%) | Acc@5 (%) | Weights | Command to reproduce |
|:--|:--|--:|--:|--:|:--|:--|
| MobileNet-V2 | | 3.5 M | 71.878 | 90.286 | | Recipe |
| MobileNet-V2 | ✅HPO | 3.5 M | 72.104 | 90.316 | | run.sh |
| MobileNet-V2 | ✅NAS | 3.5 M | 72.208 | 90.822 | | run.sh |
| ResNet-50 | | 25.6 M | 76.130 | 92.862 | | Recipe |
| ResNet-50 | ✅HPO | 25.6 M | 77.234 | 93.322 | | run.sh |
| ResNet-50 | ✅NAS | 25.6 M | 80.970 | 95.481 | | run.sh |
| EfficientNet-V2-M | | 54.1 M | 85.112 | 97.156 | | Recipe |
| EfficientNet-V2-M | ✅HPO | 54.1 M | 85.218 | 97.208 | | run.sh |
| EfficientNet-V2-M | ✅NAS | 54.1 M | 85.347 | 97.424 | | run.sh |
| ViT-B16 | | 86.6 M | 81.072 | 95.318 | | Recipe |
| ViT-B16 | ✅HPO | 86.6 M | 81.092 | 95.304 | | run.sh |
| ViT-B16 | ✅NAS | 86.6 M | 81.114 | 95.320 | | run.sh |
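For reference, the Acc@1 and Acc@5 columns are standard top-k accuracies. A minimal, framework-free sketch of the metric (the helper name and toy scores below are illustrative, not part of this repo):

```python
def topk_accuracy(logits, labels, k=1):
    """Percentage of samples whose true label is among the k highest-scoring classes."""
    hits = 0
    for scores, label in zip(logits, labels):
        # Indices of the k largest scores for this sample.
        topk = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:k]
        hits += label in topk
    return 100.0 * hits / len(labels)

# Toy batch: 3 samples, 3 classes.
logits = [[0.1, 0.7, 0.2], [0.5, 0.3, 0.2], [0.05, 0.15, 0.8]]
labels = [1, 2, 2]
print(round(topk_accuracy(logits, labels, k=1), 2))  # 66.67
```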

Table 2: Performance of pre-trained vs. SPG-retrained models. All models are evaluated on a subset of COCO val2017, restricted to the 21 categories present in the Pascal VOC dataset.

| Model | SPG | # Params | mIoU (%) | Pixelwise Acc (%) | Weights | Command to reproduce |
|:--|:--|--:|--:|--:|:--|:--|
| FCN-ResNet50 | | 35.3 M | 60.5 | 91.4 | | Recipe |
| FCN-ResNet50 | ✅HPO | 35.3 M | 60.9 | 91.6 | | run.sh |
| FCN-ResNet50 | ✅NAS | 35.3 M | 61.2 | 91.7 | | run.sh |
| FCN-ResNet101 | | 54.3 M | 63.7 | 91.9 | | Recipe |
| FCN-ResNet101 | ✅HPO | 54.3 M | 64.3 | 91.9 | | run.sh |
| FCN-ResNet101 | ✅NAS | 54.3 M | 64.6 | 92.0 | | run.sh |
| DeepLabV3-ResNet50 | | 42.0 M | 66.4 | 92.4 | | Recipe |
| DeepLabV3-ResNet50 | ✅HPO | 42.0 M | 66.6 | 92.5 | | run.sh |
| DeepLabV3-ResNet50 | ✅NAS | 42.0 M | 66.8 | 92.6 | | run.sh |
| DeepLabV3-ResNet101 | | 61.0 M | 67.4 | 92.4 | | Recipe |
| DeepLabV3-ResNet101 | ✅HPO | 61.0 M | 67.8 | 92.5 | | run.sh |
| DeepLabV3-ResNet101 | ✅NAS | 61.0 M | 68.1 | 92.8 | | run.sh |
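The mIoU column averages per-class intersection-over-union. A minimal sketch of the metric over flat label arrays (the helper name and toy maps are illustrative, not the repo's evaluator):

```python
def mean_iou(pred, target, num_classes):
    """Mean IoU over classes that appear in the prediction or the ground truth."""
    ious = []
    for c in range(num_classes):
        inter = sum(p == c and t == c for p, t in zip(pred, target))
        union = sum(p == c or t == c for p, t in zip(pred, target))
        if union:  # skip classes absent from both maps
            ious.append(inter / union)
    return 100.0 * sum(ious) / len(ious)

# Toy 1x4 "segmentation maps" with 2 classes.
print(mean_iou([0, 0, 1, 1], [0, 1, 1, 1], num_classes=2))
```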

Table 3: Performance comparison of fine-tuned vs. SPG-retrained models across NLP and speech benchmarks.

- GLUE (text classification: BERT on the CoLA, SST-2, MRPC, QQP, QNLI, and RTE tasks)
- SQuAD (question answering: BERT)
- SUPERB (speech classification: Wav2Vec2 for Audio Classification (AC))
| Task | SPG | Metric Type | Performance (%) | Weights | Command to reproduce |
|:--|:--|:--|:--|:--|:--|
| CoLA | | Matthews corr. | 56.53 | | Recipe |
| CoLA | ✅HPO | Matthews corr. | 62.13 | | run.sh |
| CoLA | ✅NAS | Matthews corr. | 63.02 | | run.sh |
| SST-2 | | Accuracy | 92.32 | | Recipe |
| SST-2 | ✅HPO | Accuracy | 92.54 | | run.sh |
| SST-2 | ✅NAS | Accuracy | 92.75 | | run.sh |
| MRPC | | F1/Accuracy | 88.85/84.09 | | Recipe |
| MRPC | ✅HPO | F1/Accuracy | 91.10/87.25 | | run.sh |
| MRPC | ✅NAS | F1/Accuracy | 91.32/87.65 | | run.sh |
| QQP | | F1/Accuracy | 87.49/90.71 | | Recipe |
| QQP | ✅HPO | F1/Accuracy | 89.72/90.88 | | run.sh |
| QQP | ✅NAS | F1/Accuracy | 89.88/91.03 | | run.sh |
| QNLI | | Accuracy | 90.66 | | Recipe |
| QNLI | ✅HPO | Accuracy | 91.10 | | run.sh |
| QNLI | ✅NAS | Accuracy | 91.27 | | run.sh |
| RTE | | Accuracy | 65.70 | | Recipe |
| RTE | ✅HPO | Accuracy | 72.56 | | run.sh |
| RTE | ✅NAS | Accuracy | 73.13 | | run.sh |
| Q/A* | | F1/Exact match | 88.52/81.22 | | Recipe |
| Q/A* | ✅HPO | F1/Exact match | 88.67/81.51 | | run.sh |
| Q/A* | ✅NAS | F1/Exact match | 88.79/81.68 | | run.sh |
| AC† | | Accuracy | 98.26 | | Recipe |
| AC† | ✅HPO | Accuracy | 98.31 | | run.sh |
| AC† | ✅NAS | Accuracy | 98.37 | | run.sh |
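CoLA is scored with the Matthews correlation coefficient rather than plain accuracy, since its label distribution is skewed. A self-contained sketch for binary labels (illustrative only, not the repo's scorer):

```python
import math

def matthews_corr(preds, labels):
    """Matthews correlation coefficient for binary predictions, in [-1, 1]."""
    tp = sum(p == 1 and l == 1 for p, l in zip(preds, labels))
    tn = sum(p == 0 and l == 0 for p, l in zip(preds, labels))
    fp = sum(p == 1 and l == 0 for p, l in zip(preds, labels))
    fn = sum(p == 0 and l == 1 for p, l in zip(preds, labels))
    denom = math.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
    # Convention: return 0 when any confusion-matrix margin is empty.
    return (tp * tn - fp * fn) / denom if denom else 0.0

print(matthews_corr([1, 1, 0, 0], [1, 1, 0, 0]))  # 1.0
```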

Table 4: Performance of SFT vs. SPG-retrained models on GSM8K

| Model | SPG | Score | Weights | Command to reproduce |
|:--|:--|--:|:--|:--|
| Gemma-2-2B-it | | 49.66 | | run.sh |
| Gemma-2-2B-it | ✅ | 52.31 | | run.sh |
| Qwen-2.5-0.5B-Instruct | | 39.12 | | run.sh |
| Qwen-2.5-0.5B-Instruct | ✅ | 41.70 | | run.sh |
| Qwen-2.5-1.5B-Instruct | | 58.68 | | run.sh |
| Qwen-2.5-1.5B-Instruct | ✅ | 59.12 | | run.sh |
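GSM8K reference solutions end with a final line of the form `#### <answer>`, and scoring typically compares that extracted number with the one parsed from the model's output. A hedged sketch of the extraction step (the helper is illustrative, not the repo's scorer):

```python
import re

def extract_gsm8k_answer(text):
    """Pull the final numeric answer after '####', stripping thousands separators."""
    match = re.search(r"####\s*([-\d,.]+)", text)
    return match.group(1).replace(",", "") if match else None

reference = "She sold 48 + 24 = 72 clips.\n#### 72"
print(extract_gsm8k_answer(reference))  # 72
```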

## Requirements

1. Install `torch>=2.0.0+cu118`.

2. Install the remaining pip packages:

   ```shell
   cd examples
   pip install -r requirements.txt
   ```
3. Prepare the ImageNet dataset manually and place it in `/path/to/imagenet`. For image classification examples, pass the argument `--data-path=/path/to/imagenet` to the training script. The extracted dataset directory should follow this structure:

   ```
   /path/to/imagenet/
       train/
           n01440764/
               n01440764_18.JPEG ...
           n01443537/
               n01443537_2.JPEG ...
       val/
           n01440764/
               ILSVRC2012_val_00000293.JPEG ...
           n01443537/
               ILSVRC2012_val_00000236.JPEG ...
   ```
4. Prepare the MS-COCO 2017 dataset manually and place it in `/path/to/coco`. For semantic segmentation examples, pass the argument `--data-path=/path/to/coco` to the training script. The extracted dataset directory should follow this structure:

   ```
   /path/to/coco/
       annotations/
           many_json_files.json ...
       train2017/
           000000000009.jpg ...
       val2017/
           000000000139.jpg ...
   ```
5. Prepare the GSM8K dataset manually and place it in `/path/to/gsm8k`. For language modeling examples, pass the argument `--data-path=/path/to/gsm8k` to the training script. The extracted dataset directory should follow this structure:

   ```
   /path/to/gsm8k/
       train.parquet
       test.parquet
   ```
6. For the 🗣️ Keyword Spotting subset, Common Language, SQuAD, Common Voice, GLUE, and WMT datasets, manual downloading is not required: they are loaded automatically via the Hugging Face Datasets library when running our audio-classification, question-answering, speech-recognition, text-classification, or translation examples.
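The expected layouts in steps 3-5 can be sanity-checked before launching a run. A small illustrative helper (not part of this repo), demonstrated on a throwaway tree mimicking the GSM8K layout from step 5:

```python
import tempfile
from pathlib import Path

def missing_paths(root, required):
    """Return the required sub-paths that do not exist under root."""
    root = Path(root)
    return [p for p in required if not (root / p).exists()]

# Build a temporary directory containing only train.parquet, then check it.
with tempfile.TemporaryDirectory() as d:
    (Path(d) / "train.parquet").touch()
    print(missing_paths(d, ["train.parquet", "test.parquet"]))  # ['test.parquet']
```

The same helper applies to the ImageNet and COCO layouts by passing `["train", "val"]` or `["annotations", "train2017", "val2017"]` as the required list.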