Upload 178 files
This view is limited to 50 files because it contains too many changes; see the raw diff for the complete set.
- .gitattributes +1 -0
- hpo-examples/audio-classification/ac/README.md +88 -0
- hpo-examples/audio-classification/ac/all_results.json +13 -0
- hpo-examples/audio-classification/ac/config.json +147 -0
- hpo-examples/audio-classification/ac/eval_results.json +8 -0
- hpo-examples/audio-classification/ac/model.safetensors +3 -0
- hpo-examples/audio-classification/ac/preprocessor_config.json +9 -0
- hpo-examples/audio-classification/ac/runs/May15_03-06-03_cs-Precision-7960-Tower/events.out.tfevents.1747292768.cs-Precision-7960-Tower.146737.0 +3 -0
- hpo-examples/audio-classification/ac/runs/May15_03-06-03_cs-Precision-7960-Tower/events.out.tfevents.1747293535.cs-Precision-7960-Tower.146737.1 +3 -0
- hpo-examples/audio-classification/ac/train_results.json +8 -0
- hpo-examples/audio-classification/ac/trainer_state.json +1598 -0
- hpo-examples/audio-classification/ac/training_args.bin +3 -0
- hpo-examples/audio-classification/requirements.txt +5 -0
- hpo-examples/audio-classification/run.sh +30 -0
- hpo-examples/audio-classification/run_audio_classification.py +462 -0
- hpo-examples/audio-classification/trplib.py +1181 -0
- hpo-examples/image-classification/__pycache__/presets.cpython-310.pyc +0 -0
- hpo-examples/image-classification/__pycache__/sampler.cpython-310.pyc +0 -0
- hpo-examples/image-classification/__pycache__/transforms.cpython-310.pyc +0 -0
- hpo-examples/image-classification/__pycache__/trplib.cpython-310.pyc +0 -0
- hpo-examples/image-classification/__pycache__/utils.cpython-310.pyc +0 -0
- hpo-examples/image-classification/efficientnet_v2_m/model_7.pth +3 -0
- hpo-examples/image-classification/mobilenetv2/model_32.pth +3 -0
- hpo-examples/image-classification/presets.py +71 -0
- hpo-examples/image-classification/resnet50/model_35.pth +3 -0
- hpo-examples/image-classification/run.sh +49 -0
- hpo-examples/image-classification/sampler.py +62 -0
- hpo-examples/image-classification/train.py +524 -0
- hpo-examples/image-classification/train_quantization.py +265 -0
- hpo-examples/image-classification/transforms.py +183 -0
- hpo-examples/image-classification/trplib.py +1181 -0
- hpo-examples/image-classification/utils.py +465 -0
- hpo-examples/image-classification/vit_b_16/model_4.pth +3 -0
- hpo-examples/question-answering/qa/README.md +55 -0
- hpo-examples/question-answering/qa/all_results.json +15 -0
- hpo-examples/question-answering/qa/config.json +26 -0
- hpo-examples/question-answering/qa/eval_nbest_predictions.json +3 -0
- hpo-examples/question-answering/qa/eval_predictions.json +0 -0
- hpo-examples/question-answering/qa/eval_results.json +9 -0
- hpo-examples/question-answering/qa/model.safetensors +3 -0
- hpo-examples/question-answering/qa/runs/May15_03-24-14_cs-Precision-7960-Tower/events.out.tfevents.1747293859.cs-Precision-7960-Tower.147971.0 +3 -0
- hpo-examples/question-answering/qa/runs/May15_03-24-14_cs-Precision-7960-Tower/events.out.tfevents.1747297197.cs-Precision-7960-Tower.147971.1 +3 -0
- hpo-examples/question-answering/qa/special_tokens_map.json +7 -0
- hpo-examples/question-answering/qa/tokenizer.json +0 -0
- hpo-examples/question-answering/qa/tokenizer_config.json +56 -0
- hpo-examples/question-answering/qa/train_results.json +9 -0
- hpo-examples/question-answering/qa/trainer_state.json +245 -0
- hpo-examples/question-answering/qa/training_args.bin +3 -0
- hpo-examples/question-answering/qa/vocab.txt +0 -0
- hpo-examples/question-answering/requirements.txt +4 -0
.gitattributes
CHANGED
@@ -37,3 +37,4 @@ qa/eval_nbest_predictions.json filter=lfs diff=lfs merge=lfs -text
 qa/sequential-policy-gradient.pdf filter=lfs diff=lfs merge=lfs -text
 sequential-policy-gradient.png filter=lfs diff=lfs merge=lfs -text
 examples/question-answering/qa/eval_nbest_predictions.json filter=lfs diff=lfs merge=lfs -text
+hpo-examples/question-answering/qa/eval_nbest_predictions.json filter=lfs diff=lfs merge=lfs -text
hpo-examples/audio-classification/ac/README.md
ADDED
@@ -0,0 +1,88 @@
+---
+library_name: transformers
+license: apache-2.0
+base_model: facebook/wav2vec2-base
+tags:
+- audio-classification
+- generated_from_trainer
+datasets:
+- superb
+metrics:
+- accuracy
+model-index:
+- name: wav2vec2-base-ft-keyword-spotting
+  results:
+  - task:
+      name: Audio Classification
+      type: audio-classification
+    dataset:
+      name: superb
+      type: superb
+      config: ks
+      split: validation
+      args: ks
+    metrics:
+    - name: Accuracy
+      type: accuracy
+      value: 0.9826419535157399
+---
+
+<!-- This model card has been generated automatically according to the information the Trainer had access to. You
+should probably proofread and complete it, then remove this comment. -->
+
+# wav2vec2-base-ft-keyword-spotting
+
+This model is a fine-tuned version of [facebook/wav2vec2-base](https://huggingface.co/facebook/wav2vec2-base) on the superb dataset.
+It achieves the following results on the evaluation set:
+- Loss: 0.0954
+- Accuracy: 0.9826
+
+## Model description
+
+More information needed
+
+## Intended uses & limitations
+
+More information needed
+
+## Training and evaluation data
+
+More information needed
+
+## Training procedure
+
+### Training hyperparameters
+
+The following hyperparameters were used during training:
+- learning_rate: 3e-05
+- train_batch_size: 48
+- eval_batch_size: 32
+- seed: 0
+- gradient_accumulation_steps: 4
+- total_train_batch_size: 192
+- optimizer: Use adamw_torch with betas=(0.9,0.999) and epsilon=1e-08 and optimizer_args=No additional optimizer arguments
+- lr_scheduler_type: linear
+- lr_scheduler_warmup_ratio: 0.1
+- num_epochs: 8.0
+- mixed_precision_training: Native AMP
+
+### Training results
+
+| Training Loss | Epoch  | Step | Validation Loss | Accuracy |
+|:-------------:|:------:|:----:|:---------------:|:--------:|
+| 1.3624        | 1.0    | 267  | 1.1959          | 0.6546   |
+| 0.3854        | 2.0    | 534  | 0.2675          | 0.9734   |
+| 0.2473        | 3.0    | 801  | 0.1461          | 0.9768   |
+| 0.1997        | 4.0    | 1068 | 0.1088          | 0.9804   |
+| 0.1723        | 5.0    | 1335 | 0.0954          | 0.9826   |
+| 0.1442        | 6.0    | 1602 | 0.0927          | 0.9813   |
+| 0.1397        | 7.0    | 1869 | 0.0892          | 0.9812   |
+| 0.1368        | 7.9728 | 2128 | 0.0896          | 0.9812   |
+
+
+### Framework versions
+
+- Transformers 4.49.0
+- Pytorch 2.6.0+cu118
+- Datasets 3.3.1
+- Tokenizers 0.21.0
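The card above is generated by the `Trainer`. As a quick orientation, here is a minimal inference sketch (not part of the upload) showing how the fine-tuned keyword-spotting checkpoint in `hpo-examples/audio-classification/ac/` could be used through the `transformers` audio-classification pipeline; the local path and the example WAV file are assumptions.

```python
# Minimal sketch: run keyword spotting with the uploaded checkpoint.
# Assumptions: the checkpoint directory exists locally and "sample_keyword.wav"
# is a short 16 kHz mono clip (neither is part of this diff).
from transformers import pipeline

classifier = pipeline(
    "audio-classification",
    model="hpo-examples/audio-classification/ac",
)

# Top predictions over the 12 keyword-spotting labels defined in config.json.
for pred in classifier("sample_keyword.wav", top_k=3):
    print(f"{pred['label']}: {pred['score']:.3f}")
```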
hpo-examples/audio-classification/ac/all_results.json
ADDED
@@ -0,0 +1,13 @@
+{
+    "epoch": 7.972769953051643,
+    "eval_accuracy": 0.9826419535157399,
+    "eval_loss": 0.09542840719223022,
+    "eval_runtime": 5.5538,
+    "eval_samples_per_second": 1224.023,
+    "eval_steps_per_second": 38.352,
+    "total_flos": 3.767900833756416e+18,
+    "train_loss": 0.5178930132572812,
+    "train_runtime": 756.2923,
+    "train_samples_per_second": 540.468,
+    "train_steps_per_second": 2.814
+}
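`all_results.json` (and the narrower `eval_results.json` / `train_results.json` below) are plain JSON metric dumps written by the `Trainer`; a minimal sketch of reading them back, assuming the path above:

```python
import json

# Load the aggregated train/eval metrics saved alongside the checkpoint.
with open("hpo-examples/audio-classification/ac/all_results.json") as f:
    metrics = json.load(f)

print(f"eval accuracy: {metrics['eval_accuracy']:.4f}")
print(f"train runtime: {metrics['train_runtime']:.1f}s "
      f"({metrics['train_samples_per_second']:.1f} samples/s)")
```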
hpo-examples/audio-classification/ac/config.json
ADDED
@@ -0,0 +1,147 @@
+{
+  "_name_or_path": "facebook/wav2vec2-base",
+  "activation_dropout": 0.0,
+  "adapter_attn_dim": null,
+  "adapter_kernel_size": 3,
+  "adapter_stride": 2,
+  "add_adapter": false,
+  "apply_spec_augment": true,
+  "architectures": [
+    "Wav2Vec2ForSequenceClassification"
+  ],
+  "attention_dropout": 0.1,
+  "bos_token_id": 1,
+  "classifier_proj_size": 256,
+  "codevector_dim": 256,
+  "contrastive_logits_temperature": 0.1,
+  "conv_bias": false,
+  "conv_dim": [
+    512,
+    512,
+    512,
+    512,
+    512,
+    512,
+    512
+  ],
+  "conv_kernel": [
+    10,
+    3,
+    3,
+    3,
+    3,
+    2,
+    2
+  ],
+  "conv_stride": [
+    5,
+    2,
+    2,
+    2,
+    2,
+    2,
+    2
+  ],
+  "ctc_loss_reduction": "sum",
+  "ctc_zero_infinity": false,
+  "diversity_loss_weight": 0.1,
+  "do_stable_layer_norm": false,
+  "eos_token_id": 2,
+  "feat_extract_activation": "gelu",
+  "feat_extract_norm": "group",
+  "feat_proj_dropout": 0.1,
+  "feat_quantizer_dropout": 0.0,
+  "final_dropout": 0.0,
+  "finetuning_task": "audio-classification",
+  "freeze_feat_extract_train": true,
+  "hidden_act": "gelu",
+  "hidden_dropout": 0.1,
+  "hidden_size": 768,
+  "id2label": {
+    "0": "yes",
+    "1": "no",
+    "10": "_silence_",
+    "11": "_unknown_",
+    "2": "up",
+    "3": "down",
+    "4": "left",
+    "5": "right",
+    "6": "on",
+    "7": "off",
+    "8": "stop",
+    "9": "go"
+  },
+  "initializer_range": 0.02,
+  "intermediate_size": 3072,
+  "label2id": {
+    "_silence_": "10",
+    "_unknown_": "11",
+    "down": "3",
+    "go": "9",
+    "left": "4",
+    "no": "1",
+    "off": "7",
+    "on": "6",
+    "right": "5",
+    "stop": "8",
+    "up": "2",
+    "yes": "0"
+  },
+  "layer_norm_eps": 1e-05,
+  "layerdrop": 0.0,
+  "mask_channel_length": 10,
+  "mask_channel_min_space": 1,
+  "mask_channel_other": 0.0,
+  "mask_channel_prob": 0.0,
+  "mask_channel_selection": "static",
+  "mask_feature_length": 10,
+  "mask_feature_min_masks": 0,
+  "mask_feature_prob": 0.0,
+  "mask_time_length": 10,
+  "mask_time_min_masks": 2,
+  "mask_time_min_space": 1,
+  "mask_time_other": 0.0,
+  "mask_time_prob": 0.05,
+  "mask_time_selection": "static",
+  "model_type": "wav2vec2",
+  "no_mask_channel_overlap": false,
+  "no_mask_time_overlap": false,
+  "num_adapter_layers": 3,
+  "num_attention_heads": 12,
+  "num_codevector_groups": 2,
+  "num_codevectors_per_group": 320,
+  "num_conv_pos_embedding_groups": 16,
+  "num_conv_pos_embeddings": 128,
+  "num_feat_extract_layers": 7,
+  "num_hidden_layers": 12,
+  "num_negatives": 100,
+  "output_hidden_size": 768,
+  "pad_token_id": 0,
+  "proj_codevector_dim": 256,
+  "tdnn_dilation": [
+    1,
+    2,
+    3,
+    1,
+    1
+  ],
+  "tdnn_dim": [
+    512,
+    512,
+    512,
+    512,
+    1500
+  ],
+  "tdnn_kernel": [
+    5,
+    3,
+    3,
+    1,
+    1
+  ],
+  "torch_dtype": "float32",
+  "transformers_version": "4.49.0",
+  "use_weighted_layer_sum": false,
+  "vocab_size": 32,
+  "xvector_output_dim": 512
+}
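The config above describes a `Wav2Vec2ForSequenceClassification` model with a 12-way keyword-spotting head. A minimal sketch of loading it and inspecting the label mapping; the local checkpoint path is an assumption:

```python
from transformers import AutoConfig, AutoModelForAudioClassification

ckpt = "hpo-examples/audio-classification/ac"  # assumed local checkpoint directory

config = AutoConfig.from_pretrained(ckpt)
print(config.model_type)   # wav2vec2
print(config.num_labels)   # 12 keyword classes (yes/no/up/down/..., _silence_, _unknown_)
print(config.id2label)     # index -> keyword mapping from config.json

# Loads Wav2Vec2ForSequenceClassification with the fine-tuned weights from model.safetensors.
model = AutoModelForAudioClassification.from_pretrained(ckpt)
print(sum(p.numel() for p in model.parameters()))  # parameter count
```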
hpo-examples/audio-classification/ac/eval_results.json
ADDED
@@ -0,0 +1,8 @@
+{
+    "epoch": 7.972769953051643,
+    "eval_accuracy": 0.9826419535157399,
+    "eval_loss": 0.09542840719223022,
+    "eval_runtime": 5.5538,
+    "eval_samples_per_second": 1224.023,
+    "eval_steps_per_second": 38.352
+}
hpo-examples/audio-classification/ac/model.safetensors
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d6e4f1c85d883f3e41ebfab4cd7752ab2e6d6b968b847795be22e1e0662657a3
+size 385400352
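This is a Git LFS pointer rather than the weights themselves: the actual ~385 MB `model.safetensors` is stored out of band and fetched when the repository is checked out with LFS or when the model is loaded via `from_pretrained`. If the resolved file is available locally, its tensors can be inspected directly with the `safetensors` library; a minimal sketch (the path is an assumption):

```python
from safetensors import safe_open

# Resolved LFS file (the real weights), not the 3-line pointer shown above.
path = "hpo-examples/audio-classification/ac/model.safetensors"

# List a few tensor names and shapes without loading the full model into memory.
with safe_open(path, framework="pt") as f:
    for name in list(f.keys())[:5]:
        print(name, tuple(f.get_tensor(name).shape))
```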
hpo-examples/audio-classification/ac/preprocessor_config.json
ADDED
@@ -0,0 +1,9 @@
+{
+  "do_normalize": true,
+  "feature_extractor_type": "Wav2Vec2FeatureExtractor",
+  "feature_size": 1,
+  "padding_side": "right",
+  "padding_value": 0.0,
+  "return_attention_mask": false,
+  "sampling_rate": 16000
+}
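The preprocessor config corresponds to a `Wav2Vec2FeatureExtractor` that expects 16 kHz mono input and applies zero-mean/unit-variance normalization. A minimal usage sketch; the one-second synthetic clip is an assumption standing in for real audio:

```python
import numpy as np
from transformers import AutoFeatureExtractor

extractor = AutoFeatureExtractor.from_pretrained(
    "hpo-examples/audio-classification/ac"  # assumed local checkpoint directory
)

# One second of (silent) audio at the expected 16 kHz sampling rate.
waveform = np.zeros(extractor.sampling_rate, dtype=np.float32)

inputs = extractor(waveform, sampling_rate=extractor.sampling_rate, return_tensors="pt")
print(inputs["input_values"].shape)  # (1, 16000)
```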
hpo-examples/audio-classification/ac/runs/May15_03-06-03_cs-Precision-7960-Tower/events.out.tfevents.1747292768.cs-Precision-7960-Tower.146737.0
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:75015326d15cb1786d8788b228e95bd7e952bad9e8e8be4ed02764cc17c8464c
+size 54953
hpo-examples/audio-classification/ac/runs/May15_03-06-03_cs-Precision-7960-Tower/events.out.tfevents.1747293535.cs-Precision-7960-Tower.146737.1
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:215a99aa1f96db505ff53265c988d18ddcf70f16a1f13e713316a0354b768356
+size 411
hpo-examples/audio-classification/ac/train_results.json
ADDED
@@ -0,0 +1,8 @@
+{
+    "epoch": 7.972769953051643,
+    "total_flos": 3.767900833756416e+18,
+    "train_loss": 0.5178930132572812,
+    "train_runtime": 756.2923,
+    "train_samples_per_second": 540.468,
+    "train_steps_per_second": 2.814
+}
hpo-examples/audio-classification/ac/trainer_state.json
ADDED
@@ -0,0 +1,1598 @@
| 1 |
+
{
|
| 2 |
+
"best_metric": 0.9826419535157399,
|
| 3 |
+
"best_model_checkpoint": "wav2vec2-base-ft-keyword-spotting/checkpoint-1335",
|
| 4 |
+
"epoch": 7.972769953051643,
|
| 5 |
+
"eval_steps": 500,
|
| 6 |
+
"global_step": 2128,
|
| 7 |
+
"is_hyper_param_search": false,
|
| 8 |
+
"is_local_process_zero": true,
|
| 9 |
+
"is_world_process_zero": true,
|
| 10 |
+
"log_history": [
|
| 11 |
+
{
|
| 12 |
+
"epoch": 0.03755868544600939,
|
| 13 |
+
"grad_norm": 2.0377416610717773,
|
| 14 |
+
"learning_rate": 1.4084507042253521e-06,
|
| 15 |
+
"loss": 3.8687,
|
| 16 |
+
"step": 10
|
| 17 |
+
},
|
| 18 |
+
{
|
| 19 |
+
"epoch": 0.07511737089201878,
|
| 20 |
+
"grad_norm": 3.055781602859497,
|
| 21 |
+
"learning_rate": 2.8169014084507042e-06,
|
| 22 |
+
"loss": 4.1156,
|
| 23 |
+
"step": 20
|
| 24 |
+
},
|
| 25 |
+
{
|
| 26 |
+
"epoch": 0.11267605633802817,
|
| 27 |
+
"grad_norm": 3.383268356323242,
|
| 28 |
+
"learning_rate": 4.225352112676057e-06,
|
| 29 |
+
"loss": 4.0885,
|
| 30 |
+
"step": 30
|
| 31 |
+
},
|
| 32 |
+
{
|
| 33 |
+
"epoch": 0.15023474178403756,
|
| 34 |
+
"grad_norm": 3.8566606044769287,
|
| 35 |
+
"learning_rate": 5.6338028169014084e-06,
|
| 36 |
+
"loss": 3.9316,
|
| 37 |
+
"step": 40
|
| 38 |
+
},
|
| 39 |
+
{
|
| 40 |
+
"epoch": 0.18779342723004694,
|
| 41 |
+
"grad_norm": 5.065456867218018,
|
| 42 |
+
"learning_rate": 7.042253521126761e-06,
|
| 43 |
+
"loss": 3.6474,
|
| 44 |
+
"step": 50
|
| 45 |
+
},
|
| 46 |
+
{
|
| 47 |
+
"epoch": 0.22535211267605634,
|
| 48 |
+
"grad_norm": 5.89341926574707,
|
| 49 |
+
"learning_rate": 8.450704225352114e-06,
|
| 50 |
+
"loss": 3.2124,
|
| 51 |
+
"step": 60
|
| 52 |
+
},
|
| 53 |
+
{
|
| 54 |
+
"epoch": 0.26291079812206575,
|
| 55 |
+
"grad_norm": 5.9929399490356445,
|
| 56 |
+
"learning_rate": 9.859154929577466e-06,
|
| 57 |
+
"loss": 2.756,
|
| 58 |
+
"step": 70
|
| 59 |
+
},
|
| 60 |
+
{
|
| 61 |
+
"epoch": 0.3004694835680751,
|
| 62 |
+
"grad_norm": 5.689433574676514,
|
| 63 |
+
"learning_rate": 1.1267605633802817e-05,
|
| 64 |
+
"loss": 2.4596,
|
| 65 |
+
"step": 80
|
| 66 |
+
},
|
| 67 |
+
{
|
| 68 |
+
"epoch": 0.3380281690140845,
|
| 69 |
+
"grad_norm": 4.89589262008667,
|
| 70 |
+
"learning_rate": 1.267605633802817e-05,
|
| 71 |
+
"loss": 2.2638,
|
| 72 |
+
"step": 90
|
| 73 |
+
},
|
| 74 |
+
{
|
| 75 |
+
"epoch": 0.3755868544600939,
|
| 76 |
+
"grad_norm": 4.8666839599609375,
|
| 77 |
+
"learning_rate": 1.4084507042253522e-05,
|
| 78 |
+
"loss": 2.1166,
|
| 79 |
+
"step": 100
|
| 80 |
+
},
|
| 81 |
+
{
|
| 82 |
+
"epoch": 0.4131455399061033,
|
| 83 |
+
"grad_norm": 4.466708660125732,
|
| 84 |
+
"learning_rate": 1.5492957746478876e-05,
|
| 85 |
+
"loss": 2.0048,
|
| 86 |
+
"step": 110
|
| 87 |
+
},
|
| 88 |
+
{
|
| 89 |
+
"epoch": 0.4507042253521127,
|
| 90 |
+
"grad_norm": 3.676050901412964,
|
| 91 |
+
"learning_rate": 1.6901408450704228e-05,
|
| 92 |
+
"loss": 1.9138,
|
| 93 |
+
"step": 120
|
| 94 |
+
},
|
| 95 |
+
{
|
| 96 |
+
"epoch": 0.48826291079812206,
|
| 97 |
+
"grad_norm": 2.183825731277466,
|
| 98 |
+
"learning_rate": 1.830985915492958e-05,
|
| 99 |
+
"loss": 1.863,
|
| 100 |
+
"step": 130
|
| 101 |
+
},
|
| 102 |
+
{
|
| 103 |
+
"epoch": 0.5258215962441315,
|
| 104 |
+
"grad_norm": 2.075413465499878,
|
| 105 |
+
"learning_rate": 1.9718309859154933e-05,
|
| 106 |
+
"loss": 1.7616,
|
| 107 |
+
"step": 140
|
| 108 |
+
},
|
| 109 |
+
{
|
| 110 |
+
"epoch": 0.5633802816901409,
|
| 111 |
+
"grad_norm": 0.8534318208694458,
|
| 112 |
+
"learning_rate": 2.112676056338028e-05,
|
| 113 |
+
"loss": 1.7185,
|
| 114 |
+
"step": 150
|
| 115 |
+
},
|
| 116 |
+
{
|
| 117 |
+
"epoch": 0.6009389671361502,
|
| 118 |
+
"grad_norm": 0.9039830565452576,
|
| 119 |
+
"learning_rate": 2.2535211267605634e-05,
|
| 120 |
+
"loss": 1.8054,
|
| 121 |
+
"step": 160
|
| 122 |
+
},
|
| 123 |
+
{
|
| 124 |
+
"epoch": 0.6384976525821596,
|
| 125 |
+
"grad_norm": 1.32124662399292,
|
| 126 |
+
"learning_rate": 2.3943661971830986e-05,
|
| 127 |
+
"loss": 1.7367,
|
| 128 |
+
"step": 170
|
| 129 |
+
},
|
| 130 |
+
{
|
| 131 |
+
"epoch": 0.676056338028169,
|
| 132 |
+
"grad_norm": 1.232069969177246,
|
| 133 |
+
"learning_rate": 2.535211267605634e-05,
|
| 134 |
+
"loss": 1.7423,
|
| 135 |
+
"step": 180
|
| 136 |
+
},
|
| 137 |
+
{
|
| 138 |
+
"epoch": 0.7136150234741784,
|
| 139 |
+
"grad_norm": 1.9570960998535156,
|
| 140 |
+
"learning_rate": 2.676056338028169e-05,
|
| 141 |
+
"loss": 1.6132,
|
| 142 |
+
"step": 190
|
| 143 |
+
},
|
| 144 |
+
{
|
| 145 |
+
"epoch": 0.7511737089201878,
|
| 146 |
+
"grad_norm": 2.4463119506835938,
|
| 147 |
+
"learning_rate": 2.8169014084507043e-05,
|
| 148 |
+
"loss": 1.6099,
|
| 149 |
+
"step": 200
|
| 150 |
+
},
|
| 151 |
+
{
|
| 152 |
+
"epoch": 0.7887323943661971,
|
| 153 |
+
"grad_norm": 6.601908206939697,
|
| 154 |
+
"learning_rate": 2.9577464788732395e-05,
|
| 155 |
+
"loss": 1.6043,
|
| 156 |
+
"step": 210
|
| 157 |
+
},
|
| 158 |
+
{
|
| 159 |
+
"epoch": 0.8262910798122066,
|
| 160 |
+
"grad_norm": 3.225101947784424,
|
| 161 |
+
"learning_rate": 2.989033942558747e-05,
|
| 162 |
+
"loss": 1.5621,
|
| 163 |
+
"step": 220
|
| 164 |
+
},
|
| 165 |
+
{
|
| 166 |
+
"epoch": 0.863849765258216,
|
| 167 |
+
"grad_norm": 3.698263645172119,
|
| 168 |
+
"learning_rate": 2.9733681462140994e-05,
|
| 169 |
+
"loss": 1.514,
|
| 170 |
+
"step": 230
|
| 171 |
+
},
|
| 172 |
+
{
|
| 173 |
+
"epoch": 0.9014084507042254,
|
| 174 |
+
"grad_norm": 5.209756374359131,
|
| 175 |
+
"learning_rate": 2.9577023498694518e-05,
|
| 176 |
+
"loss": 1.4532,
|
| 177 |
+
"step": 240
|
| 178 |
+
},
|
| 179 |
+
{
|
| 180 |
+
"epoch": 0.9389671361502347,
|
| 181 |
+
"grad_norm": 2.1304848194122314,
|
| 182 |
+
"learning_rate": 2.9420365535248042e-05,
|
| 183 |
+
"loss": 1.4312,
|
| 184 |
+
"step": 250
|
| 185 |
+
},
|
| 186 |
+
{
|
| 187 |
+
"epoch": 0.9765258215962441,
|
| 188 |
+
"grad_norm": 4.837350368499756,
|
| 189 |
+
"learning_rate": 2.926370757180157e-05,
|
| 190 |
+
"loss": 1.3624,
|
| 191 |
+
"step": 260
|
| 192 |
+
},
|
| 193 |
+
{
|
| 194 |
+
"epoch": 1.0,
|
| 195 |
+
"eval_accuracy": 0.6546042953809944,
|
| 196 |
+
"eval_loss": 1.19585382938385,
|
| 197 |
+
"eval_runtime": 4.9178,
|
| 198 |
+
"eval_samples_per_second": 1382.328,
|
| 199 |
+
"eval_steps_per_second": 43.312,
|
| 200 |
+
"step": 267
|
| 201 |
+
},
|
| 202 |
+
{
|
| 203 |
+
"epoch": 1.0112676056338028,
|
| 204 |
+
"grad_norm": 4.779292106628418,
|
| 205 |
+
"learning_rate": 2.9107049608355094e-05,
|
| 206 |
+
"loss": 1.2541,
|
| 207 |
+
"step": 270
|
| 208 |
+
},
|
| 209 |
+
{
|
| 210 |
+
"epoch": 1.0488262910798123,
|
| 211 |
+
"grad_norm": 3.60760498046875,
|
| 212 |
+
"learning_rate": 2.8950391644908618e-05,
|
| 213 |
+
"loss": 1.2271,
|
| 214 |
+
"step": 280
|
| 215 |
+
},
|
| 216 |
+
{
|
| 217 |
+
"epoch": 1.0863849765258216,
|
| 218 |
+
"grad_norm": 2.3788599967956543,
|
| 219 |
+
"learning_rate": 2.8793733681462142e-05,
|
| 220 |
+
"loss": 1.2335,
|
| 221 |
+
"step": 290
|
| 222 |
+
},
|
| 223 |
+
{
|
| 224 |
+
"epoch": 1.123943661971831,
|
| 225 |
+
"grad_norm": 3.353325843811035,
|
| 226 |
+
"learning_rate": 2.8637075718015666e-05,
|
| 227 |
+
"loss": 1.1613,
|
| 228 |
+
"step": 300
|
| 229 |
+
},
|
| 230 |
+
{
|
| 231 |
+
"epoch": 1.1615023474178403,
|
| 232 |
+
"grad_norm": 4.326411247253418,
|
| 233 |
+
"learning_rate": 2.8480417754569193e-05,
|
| 234 |
+
"loss": 1.0754,
|
| 235 |
+
"step": 310
|
| 236 |
+
},
|
| 237 |
+
{
|
| 238 |
+
"epoch": 1.1990610328638498,
|
| 239 |
+
"grad_norm": 3.1939706802368164,
|
| 240 |
+
"learning_rate": 2.8323759791122717e-05,
|
| 241 |
+
"loss": 1.0353,
|
| 242 |
+
"step": 320
|
| 243 |
+
},
|
| 244 |
+
{
|
| 245 |
+
"epoch": 1.236619718309859,
|
| 246 |
+
"grad_norm": 2.8827011585235596,
|
| 247 |
+
"learning_rate": 2.816710182767624e-05,
|
| 248 |
+
"loss": 0.9806,
|
| 249 |
+
"step": 330
|
| 250 |
+
},
|
| 251 |
+
{
|
| 252 |
+
"epoch": 1.2741784037558685,
|
| 253 |
+
"grad_norm": 3.910698652267456,
|
| 254 |
+
"learning_rate": 2.8010443864229766e-05,
|
| 255 |
+
"loss": 1.0813,
|
| 256 |
+
"step": 340
|
| 257 |
+
},
|
| 258 |
+
{
|
| 259 |
+
"epoch": 1.3117370892018778,
|
| 260 |
+
"grad_norm": 3.5916378498077393,
|
| 261 |
+
"learning_rate": 2.7853785900783293e-05,
|
| 262 |
+
"loss": 0.9792,
|
| 263 |
+
"step": 350
|
| 264 |
+
},
|
| 265 |
+
{
|
| 266 |
+
"epoch": 1.3492957746478873,
|
| 267 |
+
"grad_norm": 2.6981167793273926,
|
| 268 |
+
"learning_rate": 2.7697127937336817e-05,
|
| 269 |
+
"loss": 0.9231,
|
| 270 |
+
"step": 360
|
| 271 |
+
},
|
| 272 |
+
{
|
| 273 |
+
"epoch": 1.3868544600938968,
|
| 274 |
+
"grad_norm": 5.702897071838379,
|
| 275 |
+
"learning_rate": 2.754046997389034e-05,
|
| 276 |
+
"loss": 0.9435,
|
| 277 |
+
"step": 370
|
| 278 |
+
},
|
| 279 |
+
{
|
| 280 |
+
"epoch": 1.424413145539906,
|
| 281 |
+
"grad_norm": 4.622363090515137,
|
| 282 |
+
"learning_rate": 2.7383812010443865e-05,
|
| 283 |
+
"loss": 0.8449,
|
| 284 |
+
"step": 380
|
| 285 |
+
},
|
| 286 |
+
{
|
| 287 |
+
"epoch": 1.4619718309859155,
|
| 288 |
+
"grad_norm": 2.2103636264801025,
|
| 289 |
+
"learning_rate": 2.7227154046997393e-05,
|
| 290 |
+
"loss": 0.7713,
|
| 291 |
+
"step": 390
|
| 292 |
+
},
|
| 293 |
+
{
|
| 294 |
+
"epoch": 1.4995305164319248,
|
| 295 |
+
"grad_norm": 4.545182228088379,
|
| 296 |
+
"learning_rate": 2.7070496083550917e-05,
|
| 297 |
+
"loss": 0.7719,
|
| 298 |
+
"step": 400
|
| 299 |
+
},
|
| 300 |
+
{
|
| 301 |
+
"epoch": 1.5370892018779343,
|
| 302 |
+
"grad_norm": 6.883026599884033,
|
| 303 |
+
"learning_rate": 2.691383812010444e-05,
|
| 304 |
+
"loss": 0.7564,
|
| 305 |
+
"step": 410
|
| 306 |
+
},
|
| 307 |
+
{
|
| 308 |
+
"epoch": 1.5746478873239438,
|
| 309 |
+
"grad_norm": 4.770920276641846,
|
| 310 |
+
"learning_rate": 2.6757180156657965e-05,
|
| 311 |
+
"loss": 0.6994,
|
| 312 |
+
"step": 420
|
| 313 |
+
},
|
| 314 |
+
{
|
| 315 |
+
"epoch": 1.612206572769953,
|
| 316 |
+
"grad_norm": 4.413459300994873,
|
| 317 |
+
"learning_rate": 2.660052219321149e-05,
|
| 318 |
+
"loss": 0.6313,
|
| 319 |
+
"step": 430
|
| 320 |
+
},
|
| 321 |
+
{
|
| 322 |
+
"epoch": 1.6497652582159623,
|
| 323 |
+
"grad_norm": 2.0261390209198,
|
| 324 |
+
"learning_rate": 2.6443864229765013e-05,
|
| 325 |
+
"loss": 0.6017,
|
| 326 |
+
"step": 440
|
| 327 |
+
},
|
| 328 |
+
{
|
| 329 |
+
"epoch": 1.6873239436619718,
|
| 330 |
+
"grad_norm": 5.67121696472168,
|
| 331 |
+
"learning_rate": 2.6287206266318537e-05,
|
| 332 |
+
"loss": 0.5792,
|
| 333 |
+
"step": 450
|
| 334 |
+
},
|
| 335 |
+
{
|
| 336 |
+
"epoch": 1.7248826291079813,
|
| 337 |
+
"grad_norm": 2.573594808578491,
|
| 338 |
+
"learning_rate": 2.6146214099216712e-05,
|
| 339 |
+
"loss": 0.545,
|
| 340 |
+
"step": 460
|
| 341 |
+
},
|
| 342 |
+
{
|
| 343 |
+
"epoch": 1.7624413145539906,
|
| 344 |
+
"grad_norm": 4.145854949951172,
|
| 345 |
+
"learning_rate": 2.5989556135770236e-05,
|
| 346 |
+
"loss": 0.4907,
|
| 347 |
+
"step": 470
|
| 348 |
+
},
|
| 349 |
+
{
|
| 350 |
+
"epoch": 1.8,
|
| 351 |
+
"grad_norm": 1.7418975830078125,
|
| 352 |
+
"learning_rate": 2.583289817232376e-05,
|
| 353 |
+
"loss": 0.485,
|
| 354 |
+
"step": 480
|
| 355 |
+
},
|
| 356 |
+
{
|
| 357 |
+
"epoch": 1.8375586854460093,
|
| 358 |
+
"grad_norm": 4.651867866516113,
|
| 359 |
+
"learning_rate": 2.5676240208877287e-05,
|
| 360 |
+
"loss": 0.4572,
|
| 361 |
+
"step": 490
|
| 362 |
+
},
|
| 363 |
+
{
|
| 364 |
+
"epoch": 1.8751173708920188,
|
| 365 |
+
"grad_norm": 4.849829196929932,
|
| 366 |
+
"learning_rate": 2.551958224543081e-05,
|
| 367 |
+
"loss": 0.4864,
|
| 368 |
+
"step": 500
|
| 369 |
+
},
|
| 370 |
+
{
|
| 371 |
+
"epoch": 1.9126760563380283,
|
| 372 |
+
"grad_norm": 2.631229877471924,
|
| 373 |
+
"learning_rate": 2.5362924281984335e-05,
|
| 374 |
+
"loss": 0.4035,
|
| 375 |
+
"step": 510
|
| 376 |
+
},
|
| 377 |
+
{
|
| 378 |
+
"epoch": 1.9502347417840376,
|
| 379 |
+
"grad_norm": 5.099828243255615,
|
| 380 |
+
"learning_rate": 2.520626631853786e-05,
|
| 381 |
+
"loss": 0.3818,
|
| 382 |
+
"step": 520
|
| 383 |
+
},
|
| 384 |
+
{
|
| 385 |
+
"epoch": 1.9877934272300468,
|
| 386 |
+
"grad_norm": 3.25174617767334,
|
| 387 |
+
"learning_rate": 2.5049608355091387e-05,
|
| 388 |
+
"loss": 0.3854,
|
| 389 |
+
"step": 530
|
| 390 |
+
},
|
| 391 |
+
{
|
| 392 |
+
"epoch": 2.0,
|
| 393 |
+
"eval_accuracy": 0.9733745219182113,
|
| 394 |
+
"eval_loss": 0.2675245702266693,
|
| 395 |
+
"eval_runtime": 5.0739,
|
| 396 |
+
"eval_samples_per_second": 1339.801,
|
| 397 |
+
"eval_steps_per_second": 41.98,
|
| 398 |
+
"step": 534
|
| 399 |
+
},
|
| 400 |
+
{
|
| 401 |
+
"epoch": 2.0225352112676056,
|
| 402 |
+
"grad_norm": 4.533545017242432,
|
| 403 |
+
"learning_rate": 2.489295039164491e-05,
|
| 404 |
+
"loss": 0.3589,
|
| 405 |
+
"step": 540
|
| 406 |
+
},
|
| 407 |
+
{
|
| 408 |
+
"epoch": 2.060093896713615,
|
| 409 |
+
"grad_norm": 4.4245991706848145,
|
| 410 |
+
"learning_rate": 2.4736292428198435e-05,
|
| 411 |
+
"loss": 0.378,
|
| 412 |
+
"step": 550
|
| 413 |
+
},
|
| 414 |
+
{
|
| 415 |
+
"epoch": 2.0976525821596246,
|
| 416 |
+
"grad_norm": 5.778880596160889,
|
| 417 |
+
"learning_rate": 2.457963446475196e-05,
|
| 418 |
+
"loss": 0.3653,
|
| 419 |
+
"step": 560
|
| 420 |
+
},
|
| 421 |
+
{
|
| 422 |
+
"epoch": 2.1352112676056336,
|
| 423 |
+
"grad_norm": 3.5573890209198,
|
| 424 |
+
"learning_rate": 2.4422976501305487e-05,
|
| 425 |
+
"loss": 0.3107,
|
| 426 |
+
"step": 570
|
| 427 |
+
},
|
| 428 |
+
{
|
| 429 |
+
"epoch": 2.172769953051643,
|
| 430 |
+
"grad_norm": 3.655824899673462,
|
| 431 |
+
"learning_rate": 2.426631853785901e-05,
|
| 432 |
+
"loss": 0.3405,
|
| 433 |
+
"step": 580
|
| 434 |
+
},
|
| 435 |
+
{
|
| 436 |
+
"epoch": 2.2103286384976526,
|
| 437 |
+
"grad_norm": 2.430022954940796,
|
| 438 |
+
"learning_rate": 2.4109660574412535e-05,
|
| 439 |
+
"loss": 0.3298,
|
| 440 |
+
"step": 590
|
| 441 |
+
},
|
| 442 |
+
{
|
| 443 |
+
"epoch": 2.247887323943662,
|
| 444 |
+
"grad_norm": 2.9207568168640137,
|
| 445 |
+
"learning_rate": 2.3953002610966055e-05,
|
| 446 |
+
"loss": 0.3022,
|
| 447 |
+
"step": 600
|
| 448 |
+
},
|
| 449 |
+
{
|
| 450 |
+
"epoch": 2.2854460093896716,
|
| 451 |
+
"grad_norm": 4.8787007331848145,
|
| 452 |
+
"learning_rate": 2.3796344647519583e-05,
|
| 453 |
+
"loss": 0.3991,
|
| 454 |
+
"step": 610
|
| 455 |
+
},
|
| 456 |
+
{
|
| 457 |
+
"epoch": 2.3230046948356806,
|
| 458 |
+
"grad_norm": 3.0268468856811523,
|
| 459 |
+
"learning_rate": 2.3639686684073107e-05,
|
| 460 |
+
"loss": 0.3159,
|
| 461 |
+
"step": 620
|
| 462 |
+
},
|
| 463 |
+
{
|
| 464 |
+
"epoch": 2.36056338028169,
|
| 465 |
+
"grad_norm": 2.6611557006835938,
|
| 466 |
+
"learning_rate": 2.348302872062663e-05,
|
| 467 |
+
"loss": 0.2868,
|
| 468 |
+
"step": 630
|
| 469 |
+
},
|
| 470 |
+
{
|
| 471 |
+
"epoch": 2.3981220657276996,
|
| 472 |
+
"grad_norm": 2.485551595687866,
|
| 473 |
+
"learning_rate": 2.3326370757180155e-05,
|
| 474 |
+
"loss": 0.3032,
|
| 475 |
+
"step": 640
|
| 476 |
+
},
|
| 477 |
+
{
|
| 478 |
+
"epoch": 2.435680751173709,
|
| 479 |
+
"grad_norm": 4.556153297424316,
|
| 480 |
+
"learning_rate": 2.316971279373368e-05,
|
| 481 |
+
"loss": 0.2985,
|
| 482 |
+
"step": 650
|
| 483 |
+
},
|
| 484 |
+
{
|
| 485 |
+
"epoch": 2.473239436619718,
|
| 486 |
+
"grad_norm": 5.270796298980713,
|
| 487 |
+
"learning_rate": 2.3013054830287207e-05,
|
| 488 |
+
"loss": 0.2839,
|
| 489 |
+
"step": 660
|
| 490 |
+
},
|
| 491 |
+
{
|
| 492 |
+
"epoch": 2.5107981220657276,
|
| 493 |
+
"grad_norm": 3.347005844116211,
|
| 494 |
+
"learning_rate": 2.285639686684073e-05,
|
| 495 |
+
"loss": 0.2871,
|
| 496 |
+
"step": 670
|
| 497 |
+
},
|
| 498 |
+
{
|
| 499 |
+
"epoch": 2.548356807511737,
|
| 500 |
+
"grad_norm": 5.236591815948486,
|
| 501 |
+
"learning_rate": 2.2699738903394255e-05,
|
| 502 |
+
"loss": 0.3028,
|
| 503 |
+
"step": 680
|
| 504 |
+
},
|
| 505 |
+
{
|
| 506 |
+
"epoch": 2.5859154929577466,
|
| 507 |
+
"grad_norm": 2.995059013366699,
|
| 508 |
+
"learning_rate": 2.254308093994778e-05,
|
| 509 |
+
"loss": 0.2537,
|
| 510 |
+
"step": 690
|
| 511 |
+
},
|
| 512 |
+
{
|
| 513 |
+
"epoch": 2.6234741784037556,
|
| 514 |
+
"grad_norm": 2.805640459060669,
|
| 515 |
+
"learning_rate": 2.2386422976501306e-05,
|
| 516 |
+
"loss": 0.297,
|
| 517 |
+
"step": 700
|
| 518 |
+
},
|
| 519 |
+
{
|
| 520 |
+
"epoch": 2.661032863849765,
|
| 521 |
+
"grad_norm": 3.0646071434020996,
|
| 522 |
+
"learning_rate": 2.222976501305483e-05,
|
| 523 |
+
"loss": 0.2453,
|
| 524 |
+
"step": 710
|
| 525 |
+
},
|
| 526 |
+
{
|
| 527 |
+
"epoch": 2.6985915492957746,
|
| 528 |
+
"grad_norm": 3.6719613075256348,
|
| 529 |
+
"learning_rate": 2.2073107049608354e-05,
|
| 530 |
+
"loss": 0.2655,
|
| 531 |
+
"step": 720
|
| 532 |
+
},
|
| 533 |
+
{
|
| 534 |
+
"epoch": 2.736150234741784,
|
| 535 |
+
"grad_norm": 3.2248122692108154,
|
| 536 |
+
"learning_rate": 2.191644908616188e-05,
|
| 537 |
+
"loss": 0.2297,
|
| 538 |
+
"step": 730
|
| 539 |
+
},
|
| 540 |
+
{
|
| 541 |
+
"epoch": 2.7737089201877936,
|
| 542 |
+
"grad_norm": 3.769843578338623,
|
| 543 |
+
"learning_rate": 2.1759791122715406e-05,
|
| 544 |
+
"loss": 0.2548,
|
| 545 |
+
"step": 740
|
| 546 |
+
},
|
| 547 |
+
{
|
| 548 |
+
"epoch": 2.8112676056338026,
|
| 549 |
+
"grad_norm": 3.6679906845092773,
|
| 550 |
+
"learning_rate": 2.160313315926893e-05,
|
| 551 |
+
"loss": 0.2836,
|
| 552 |
+
"step": 750
|
| 553 |
+
},
|
| 554 |
+
{
|
| 555 |
+
"epoch": 2.848826291079812,
|
| 556 |
+
"grad_norm": 1.6924936771392822,
|
| 557 |
+
"learning_rate": 2.1446475195822454e-05,
|
| 558 |
+
"loss": 0.2555,
|
| 559 |
+
"step": 760
|
| 560 |
+
},
|
| 561 |
+
{
|
| 562 |
+
"epoch": 2.8863849765258216,
|
| 563 |
+
"grad_norm": 2.1275901794433594,
|
| 564 |
+
"learning_rate": 2.1289817232375978e-05,
|
| 565 |
+
"loss": 0.2334,
|
| 566 |
+
"step": 770
|
| 567 |
+
},
|
| 568 |
+
{
|
| 569 |
+
"epoch": 2.923943661971831,
|
| 570 |
+
"grad_norm": 6.528135299682617,
|
| 571 |
+
"learning_rate": 2.1133159268929506e-05,
|
| 572 |
+
"loss": 0.2544,
|
| 573 |
+
"step": 780
|
| 574 |
+
},
|
| 575 |
+
{
|
| 576 |
+
"epoch": 2.9615023474178406,
|
| 577 |
+
"grad_norm": 2.4497199058532715,
|
| 578 |
+
"learning_rate": 2.097650130548303e-05,
|
| 579 |
+
"loss": 0.2628,
|
| 580 |
+
"step": 790
|
| 581 |
+
},
|
| 582 |
+
{
|
| 583 |
+
"epoch": 2.9990610328638496,
|
| 584 |
+
"grad_norm": 2.278947591781616,
|
| 585 |
+
"learning_rate": 2.0819843342036554e-05,
|
| 586 |
+
"loss": 0.2473,
|
| 587 |
+
"step": 800
|
| 588 |
+
},
|
| 589 |
+
{
|
| 590 |
+
"epoch": 3.0,
|
| 591 |
+
"eval_accuracy": 0.9767578699617535,
|
| 592 |
+
"eval_loss": 0.1461225152015686,
|
| 593 |
+
"eval_runtime": 5.0057,
|
| 594 |
+
"eval_samples_per_second": 1358.045,
|
| 595 |
+
"eval_steps_per_second": 42.551,
|
| 596 |
+
"step": 801
|
| 597 |
+
},
|
| 598 |
+
{
|
| 599 |
+
"epoch": 3.0338028169014084,
|
| 600 |
+
"grad_norm": 3.1185402870178223,
|
| 601 |
+
"learning_rate": 2.0663185378590078e-05,
|
| 602 |
+
"loss": 0.2245,
|
| 603 |
+
"step": 810
|
| 604 |
+
},
|
| 605 |
+
{
|
| 606 |
+
"epoch": 3.071361502347418,
|
| 607 |
+
"grad_norm": 2.456102132797241,
|
| 608 |
+
"learning_rate": 2.0506527415143602e-05,
|
| 609 |
+
"loss": 0.2423,
|
| 610 |
+
"step": 820
|
| 611 |
+
},
|
| 612 |
+
{
|
| 613 |
+
"epoch": 3.1089201877934274,
|
| 614 |
+
"grad_norm": 2.9463231563568115,
|
| 615 |
+
"learning_rate": 2.034986945169713e-05,
|
| 616 |
+
"loss": 0.2274,
|
| 617 |
+
"step": 830
|
| 618 |
+
},
|
| 619 |
+
{
|
| 620 |
+
"epoch": 3.1464788732394364,
|
| 621 |
+
"grad_norm": 3.5940473079681396,
|
| 622 |
+
"learning_rate": 2.0193211488250653e-05,
|
| 623 |
+
"loss": 0.2368,
|
| 624 |
+
"step": 840
|
| 625 |
+
},
|
| 626 |
+
{
|
| 627 |
+
"epoch": 3.184037558685446,
|
| 628 |
+
"grad_norm": 4.721577167510986,
|
| 629 |
+
"learning_rate": 2.0036553524804177e-05,
|
| 630 |
+
"loss": 0.2554,
|
| 631 |
+
"step": 850
|
| 632 |
+
},
|
| 633 |
+
{
|
| 634 |
+
"epoch": 3.2215962441314554,
|
| 635 |
+
"grad_norm": 2.496495485305786,
|
| 636 |
+
"learning_rate": 1.98798955613577e-05,
|
| 637 |
+
"loss": 0.2363,
|
| 638 |
+
"step": 860
|
| 639 |
+
},
|
| 640 |
+
{
|
| 641 |
+
"epoch": 3.259154929577465,
|
| 642 |
+
"grad_norm": 3.0665740966796875,
|
| 643 |
+
"learning_rate": 1.972323759791123e-05,
|
| 644 |
+
"loss": 0.2248,
|
| 645 |
+
"step": 870
|
| 646 |
+
},
|
| 647 |
+
{
|
| 648 |
+
"epoch": 3.2967136150234744,
|
| 649 |
+
"grad_norm": 4.336172580718994,
|
| 650 |
+
"learning_rate": 1.9566579634464753e-05,
|
| 651 |
+
"loss": 0.1922,
|
| 652 |
+
"step": 880
|
| 653 |
+
},
|
| 654 |
+
{
|
| 655 |
+
"epoch": 3.3342723004694834,
|
| 656 |
+
"grad_norm": 4.110763072967529,
|
| 657 |
+
"learning_rate": 1.9409921671018277e-05,
|
| 658 |
+
"loss": 0.1965,
|
| 659 |
+
"step": 890
|
| 660 |
+
},
|
| 661 |
+
{
|
| 662 |
+
"epoch": 3.371830985915493,
|
| 663 |
+
"grad_norm": 1.9457247257232666,
|
| 664 |
+
"learning_rate": 1.92532637075718e-05,
|
| 665 |
+
"loss": 0.2258,
|
| 666 |
+
"step": 900
|
| 667 |
+
},
|
| 668 |
+
{
|
| 669 |
+
"epoch": 3.4093896713615024,
|
| 670 |
+
"grad_norm": 2.719369411468506,
|
| 671 |
+
"learning_rate": 1.909660574412533e-05,
|
| 672 |
+
"loss": 0.2184,
|
| 673 |
+
"step": 910
|
| 674 |
+
},
|
| 675 |
+
{
|
| 676 |
+
"epoch": 3.446948356807512,
|
| 677 |
+
"grad_norm": 3.438279151916504,
|
| 678 |
+
"learning_rate": 1.8939947780678853e-05,
|
| 679 |
+
"loss": 0.1964,
|
| 680 |
+
"step": 920
|
| 681 |
+
},
|
| 682 |
+
{
|
| 683 |
+
"epoch": 3.4845070422535214,
|
| 684 |
+
"grad_norm": 3.2813045978546143,
|
| 685 |
+
"learning_rate": 1.8783289817232377e-05,
|
| 686 |
+
"loss": 0.2348,
|
| 687 |
+
"step": 930
|
| 688 |
+
},
|
| 689 |
+
{
|
| 690 |
+
"epoch": 3.5220657276995304,
|
| 691 |
+
"grad_norm": 4.151478290557861,
|
| 692 |
+
"learning_rate": 1.86266318537859e-05,
|
| 693 |
+
"loss": 0.2004,
|
| 694 |
+
"step": 940
|
| 695 |
+
},
|
| 696 |
+
{
|
| 697 |
+
"epoch": 3.55962441314554,
|
| 698 |
+
"grad_norm": 3.4271771907806396,
|
| 699 |
+
"learning_rate": 1.8469973890339425e-05,
|
| 700 |
+
"loss": 0.2039,
|
| 701 |
+
"step": 950
|
| 702 |
+
},
|
| 703 |
+
{
|
| 704 |
+
"epoch": 3.5971830985915494,
|
| 705 |
+
"grad_norm": 4.0341901779174805,
|
| 706 |
+
"learning_rate": 1.8313315926892952e-05,
|
| 707 |
+
"loss": 0.1997,
|
| 708 |
+
"step": 960
|
| 709 |
+
},
|
| 710 |
+
{
|
| 711 |
+
"epoch": 3.6347417840375584,
|
| 712 |
+
"grad_norm": 4.762091636657715,
|
| 713 |
+
"learning_rate": 1.8156657963446476e-05,
|
| 714 |
+
"loss": 0.2153,
|
| 715 |
+
"step": 970
|
| 716 |
+
},
|
| 717 |
+
{
|
| 718 |
+
"epoch": 3.672300469483568,
|
| 719 |
+
"grad_norm": 3.3214402198791504,
|
| 720 |
+
"learning_rate": 1.8e-05,
|
| 721 |
+
"loss": 0.1801,
|
| 722 |
+
"step": 980
|
| 723 |
+
},
|
| 724 |
+
{
|
| 725 |
+
"epoch": 3.7098591549295774,
|
| 726 |
+
"grad_norm": 3.84503173828125,
|
| 727 |
+
"learning_rate": 1.7843342036553525e-05,
|
| 728 |
+
"loss": 0.2106,
|
| 729 |
+
"step": 990
|
| 730 |
+
},
|
| 731 |
+
{
|
| 732 |
+
"epoch": 3.747417840375587,
|
| 733 |
+
"grad_norm": 3.303781747817993,
|
| 734 |
+
"learning_rate": 1.7686684073107052e-05,
|
| 735 |
+
"loss": 0.1965,
|
| 736 |
+
"step": 1000
|
| 737 |
+
},
|
| 738 |
+
{
|
| 739 |
+
"epoch": 3.7849765258215964,
|
| 740 |
+
"grad_norm": 2.691159248352051,
|
| 741 |
+
"learning_rate": 1.7530026109660576e-05,
|
| 742 |
+
"loss": 0.193,
|
| 743 |
+
"step": 1010
|
| 744 |
+
},
|
| 745 |
+
{
|
| 746 |
+
"epoch": 3.8225352112676054,
|
| 747 |
+
"grad_norm": 4.134768009185791,
|
| 748 |
+
"learning_rate": 1.73733681462141e-05,
|
| 749 |
+
"loss": 0.1908,
|
| 750 |
+
"step": 1020
|
| 751 |
+
},
|
| 752 |
+
{
|
| 753 |
+
"epoch": 3.860093896713615,
|
| 754 |
+
"grad_norm": 2.9195241928100586,
|
| 755 |
+
"learning_rate": 1.7216710182767624e-05,
|
| 756 |
+
"loss": 0.1886,
|
| 757 |
+
"step": 1030
|
| 758 |
+
},
|
| 759 |
+
{
|
| 760 |
+
"epoch": 3.8976525821596244,
|
| 761 |
+
"grad_norm": 3.795133352279663,
|
| 762 |
+
"learning_rate": 1.706005221932115e-05,
|
| 763 |
+
"loss": 0.2007,
|
| 764 |
+
"step": 1040
|
| 765 |
+
},
|
| 766 |
+
{
|
| 767 |
+
"epoch": 3.935211267605634,
|
| 768 |
+
"grad_norm": 3.9436607360839844,
|
| 769 |
+
"learning_rate": 1.6903394255874676e-05,
|
| 770 |
+
"loss": 0.1834,
|
| 771 |
+
"step": 1050
|
| 772 |
+
},
|
| 773 |
+
{
|
| 774 |
+
"epoch": 3.9727699530516434,
|
| 775 |
+
"grad_norm": 3.4115564823150635,
|
| 776 |
+
"learning_rate": 1.67467362924282e-05,
|
| 777 |
+
"loss": 0.1997,
|
| 778 |
+
"step": 1060
|
| 779 |
+
},
|
| 780 |
+
{
|
| 781 |
+
"epoch": 4.0,
|
| 782 |
+
"eval_accuracy": 0.980435422182995,
|
| 783 |
+
"eval_loss": 0.10877315700054169,
|
| 784 |
+
"eval_runtime": 4.9191,
|
| 785 |
+
"eval_samples_per_second": 1381.955,
|
| 786 |
+
"eval_steps_per_second": 43.3,
|
| 787 |
+
"step": 1068
|
| 788 |
+
},
|
| 789 |
+
{
|
| 790 |
+
"epoch": 4.007511737089202,
|
| 791 |
+
"grad_norm": 5.121041774749756,
|
| 792 |
+
"learning_rate": 1.6590078328981724e-05,
|
| 793 |
+
"loss": 0.1785,
|
| 794 |
+
"step": 1070
|
| 795 |
+
},
|
| 796 |
+
{
|
| 797 |
+
"epoch": 4.045070422535211,
|
| 798 |
+
"grad_norm": 2.908527374267578,
|
| 799 |
+
"learning_rate": 1.643342036553525e-05,
|
| 800 |
+
"loss": 0.1678,
|
| 801 |
+
"step": 1080
|
| 802 |
+
},
|
| 803 |
+
{
|
| 804 |
+
"epoch": 4.08262910798122,
|
| 805 |
+
"grad_norm": 1.9687402248382568,
|
| 806 |
+
"learning_rate": 1.6276762402088775e-05,
|
| 807 |
+
"loss": 0.192,
|
| 808 |
+
"step": 1090
|
| 809 |
+
},
|
| 810 |
+
{
|
| 811 |
+
"epoch": 4.12018779342723,
|
| 812 |
+
"grad_norm": 2.722937822341919,
|
| 813 |
+
"learning_rate": 1.61201044386423e-05,
|
| 814 |
+
"loss": 0.1983,
|
| 815 |
+
"step": 1100
|
| 816 |
+
},
|
| 817 |
+
{
|
| 818 |
+
"epoch": 4.157746478873239,
|
| 819 |
+
"grad_norm": 2.3741490840911865,
|
| 820 |
+
"learning_rate": 1.5963446475195823e-05,
|
| 821 |
+
"loss": 0.2145,
|
| 822 |
+
"step": 1110
|
| 823 |
+
},
|
| 824 |
+
{
|
| 825 |
+
"epoch": 4.195305164319249,
|
| 826 |
+
"grad_norm": 2.653414011001587,
|
| 827 |
+
"learning_rate": 1.5806788511749348e-05,
|
| 828 |
+
"loss": 0.1701,
|
| 829 |
+
"step": 1120
|
| 830 |
+
},
|
| 831 |
+
{
|
| 832 |
+
"epoch": 4.232863849765258,
|
| 833 |
+
"grad_norm": 3.444087266921997,
|
| 834 |
+
"learning_rate": 1.5650130548302875e-05,
|
| 835 |
+
"loss": 0.2047,
|
| 836 |
+
"step": 1130
|
| 837 |
+
},
|
| 838 |
+
{
|
| 839 |
+
"epoch": 4.270422535211267,
|
| 840 |
+
"grad_norm": 2.024235486984253,
|
| 841 |
+
"learning_rate": 1.54934725848564e-05,
|
| 842 |
+
"loss": 0.1817,
|
| 843 |
+
"step": 1140
|
| 844 |
+
},
|
| 845 |
+
{
|
| 846 |
+
"epoch": 4.307981220657277,
|
| 847 |
+
"grad_norm": 2.742171049118042,
|
| 848 |
+
"learning_rate": 1.533681462140992e-05,
|
| 849 |
+
"loss": 0.1723,
|
| 850 |
+
"step": 1150
|
| 851 |
+
},
|
| 852 |
+
{
|
| 853 |
+
"epoch": 4.345539906103286,
|
| 854 |
+
"grad_norm": 3.3700480461120605,
|
| 855 |
+
"learning_rate": 1.5180156657963446e-05,
|
| 856 |
+
"loss": 0.17,
|
| 857 |
+
"step": 1160
|
| 858 |
+
},
|
| 859 |
+
{
|
| 860 |
+
"epoch": 4.383098591549296,
|
| 861 |
+
"grad_norm": 2.552915573120117,
|
| 862 |
+
"learning_rate": 1.5023498694516973e-05,
|
| 863 |
+
"loss": 0.1802,
|
| 864 |
+
"step": 1170
|
| 865 |
+
},
|
| 866 |
+
{
|
| 867 |
+
"epoch": 4.420657276995305,
|
| 868 |
+
"grad_norm": 3.3317511081695557,
|
| 869 |
+
"learning_rate": 1.4866840731070497e-05,
|
| 870 |
+
"loss": 0.1933,
|
| 871 |
+
"step": 1180
|
| 872 |
+
},
|
| 873 |
+
{
|
| 874 |
+
"epoch": 4.458215962441314,
|
| 875 |
+
"grad_norm": 1.9266548156738281,
|
| 876 |
+
"learning_rate": 1.4710182767624021e-05,
|
| 877 |
+
"loss": 0.1739,
|
| 878 |
+
"step": 1190
|
| 879 |
+
},
|
| 880 |
+
{
|
| 881 |
+
"epoch": 4.495774647887324,
|
| 882 |
+
"grad_norm": 2.1459243297576904,
|
| 883 |
+
"learning_rate": 1.4553524804177547e-05,
|
| 884 |
+
"loss": 0.1599,
|
| 885 |
+
"step": 1200
|
| 886 |
+
},
|
| 887 |
+
{
|
| 888 |
+
"epoch": 4.533333333333333,
|
| 889 |
+
"grad_norm": 3.9314770698547363,
|
| 890 |
+
"learning_rate": 1.4396866840731071e-05,
|
| 891 |
+
"loss": 0.1958,
|
| 892 |
+
"step": 1210
|
| 893 |
+
},
|
| 894 |
+
{
|
| 895 |
+
"epoch": 4.570892018779343,
|
| 896 |
+
"grad_norm": 2.6377363204956055,
|
| 897 |
+
"learning_rate": 1.4240208877284597e-05,
|
| 898 |
+
"loss": 0.1604,
|
| 899 |
+
"step": 1220
|
| 900 |
+
},
|
| 901 |
+
{
|
| 902 |
+
"epoch": 4.608450704225352,
|
| 903 |
+
"grad_norm": 2.810866594314575,
|
| 904 |
+
"learning_rate": 1.408355091383812e-05,
|
| 905 |
+
"loss": 0.1495,
|
| 906 |
+
"step": 1230
|
| 907 |
+
},
|
| 908 |
+
{
|
| 909 |
+
"epoch": 4.646009389671361,
|
| 910 |
+
"grad_norm": 2.2084455490112305,
|
| 911 |
+
"learning_rate": 1.3926892950391646e-05,
|
| 912 |
+
"loss": 0.185,
|
| 913 |
+
"step": 1240
|
| 914 |
+
},
|
| 915 |
+
{
|
| 916 |
+
"epoch": 4.683568075117371,
|
| 917 |
+
"grad_norm": 2.7217283248901367,
|
| 918 |
+
"learning_rate": 1.377023498694517e-05,
|
| 919 |
+
"loss": 0.1757,
|
| 920 |
+
"step": 1250
|
| 921 |
+
},
|
| 922 |
+
{
|
| 923 |
+
"epoch": 4.72112676056338,
|
| 924 |
+
"grad_norm": 3.075267791748047,
|
| 925 |
+
"learning_rate": 1.3613577023498696e-05,
|
| 926 |
+
"loss": 0.1814,
|
| 927 |
+
"step": 1260
|
| 928 |
+
},
|
| 929 |
+
{
|
| 930 |
+
"epoch": 4.758685446009389,
|
| 931 |
+
"grad_norm": 3.2452406883239746,
|
| 932 |
+
"learning_rate": 1.345691906005222e-05,
|
| 933 |
+
"loss": 0.1622,
|
| 934 |
+
"step": 1270
|
| 935 |
+
},
|
| 936 |
+
{
|
| 937 |
+
"epoch": 4.796244131455399,
|
| 938 |
+
"grad_norm": 2.712754487991333,
|
| 939 |
+
"learning_rate": 1.3300261096605744e-05,
|
| 940 |
+
"loss": 0.1714,
|
| 941 |
+
"step": 1280
|
| 942 |
+
},
|
| 943 |
+
{
|
| 944 |
+
"epoch": 4.833802816901408,
|
| 945 |
+
"grad_norm": 1.6795600652694702,
|
| 946 |
+
"learning_rate": 1.3143603133159269e-05,
|
| 947 |
+
"loss": 0.1519,
|
| 948 |
+
"step": 1290
|
| 949 |
+
},
|
| 950 |
+
{
|
| 951 |
+
"epoch": 4.871361502347418,
|
| 952 |
+
"grad_norm": 3.9085493087768555,
|
| 953 |
+
"learning_rate": 1.2986945169712793e-05,
|
| 954 |
+
"loss": 0.1758,
|
| 955 |
+
"step": 1300
|
| 956 |
+
},
|
| 957 |
+
{
|
| 958 |
+
"epoch": 4.908920187793427,
|
| 959 |
+
"grad_norm": 3.529478073120117,
|
| 960 |
+
"learning_rate": 1.2830287206266318e-05,
|
| 961 |
+
"loss": 0.1549,
|
| 962 |
+
"step": 1310
|
| 963 |
+
},
|
| 964 |
+
{
|
| 965 |
+
"epoch": 4.946478873239436,
|
| 966 |
+
"grad_norm": 2.559157609939575,
|
| 967 |
+
"learning_rate": 1.2673629242819842e-05,
|
| 968 |
+
"loss": 0.1824,
|
| 969 |
+
"step": 1320
|
| 970 |
+
},
|
| 971 |
+
{
|
| 972 |
+
"epoch": 4.984037558685446,
|
| 973 |
+
"grad_norm": 2.2350497245788574,
|
| 974 |
+
"learning_rate": 1.2516971279373368e-05,
|
| 975 |
+
"loss": 0.1723,
|
| 976 |
+
"step": 1330
|
| 977 |
+
},
|
| 978 |
+
{
|
| 979 |
+
"epoch": 5.0,
|
| 980 |
+
"eval_accuracy": 0.9826419535157399,
|
| 981 |
+
"eval_loss": 0.09542840719223022,
|
| 982 |
+
"eval_runtime": 5.0389,
|
| 983 |
+
"eval_samples_per_second": 1349.105,
|
| 984 |
+
"eval_steps_per_second": 42.271,
|
| 985 |
+
"step": 1335
|
| 986 |
+
},
|
| 987 |
+
{
|
| 988 |
+
"epoch": 5.018779342723005,
|
| 989 |
+
"grad_norm": 2.5073907375335693,
|
| 990 |
+
"learning_rate": 1.2360313315926892e-05,
|
| 991 |
+
"loss": 0.1401,
|
| 992 |
+
"step": 1340
|
| 993 |
+
},
|
| 994 |
+
{
|
| 995 |
+
"epoch": 5.056338028169014,
|
| 996 |
+
"grad_norm": 4.696757793426514,
|
| 997 |
+
"learning_rate": 1.2203655352480418e-05,
|
| 998 |
+
"loss": 0.1801,
|
| 999 |
+
"step": 1350
|
| 1000 |
+
},
|
| 1001 |
+
{
|
| 1002 |
+
"epoch": 5.093896713615023,
|
| 1003 |
+
"grad_norm": 1.2180489301681519,
|
| 1004 |
+
"learning_rate": 1.2046997389033942e-05,
|
| 1005 |
+
"loss": 0.1335,
|
| 1006 |
+
"step": 1360
|
| 1007 |
+
},
|
| 1008 |
+
{
|
| 1009 |
+
"epoch": 5.131455399061033,
|
| 1010 |
+
"grad_norm": 0.887860119342804,
|
| 1011 |
+
"learning_rate": 1.1890339425587468e-05,
|
| 1012 |
+
"loss": 0.1479,
|
| 1013 |
+
"step": 1370
|
| 1014 |
+
},
|
| 1015 |
+
{
|
| 1016 |
+
"epoch": 5.169014084507042,
|
| 1017 |
+
"grad_norm": 3.6347432136535645,
|
| 1018 |
+
"learning_rate": 1.1733681462140992e-05,
|
| 1019 |
+
"loss": 0.1575,
|
| 1020 |
+
"step": 1380
|
| 1021 |
+
},
|
| 1022 |
+
{
|
| 1023 |
+
"epoch": 5.206572769953052,
|
| 1024 |
+
"grad_norm": 2.901700496673584,
|
| 1025 |
+
"learning_rate": 1.1577023498694518e-05,
|
| 1026 |
+
"loss": 0.1367,
|
| 1027 |
+
"step": 1390
|
| 1028 |
+
},
|
| 1029 |
+
{
|
| 1030 |
+
"epoch": 5.244131455399061,
|
| 1031 |
+
"grad_norm": 2.6395390033721924,
|
| 1032 |
+
"learning_rate": 1.1420365535248042e-05,
|
| 1033 |
+
"loss": 0.144,
|
| 1034 |
+
"step": 1400
|
| 1035 |
+
},
|
| 1036 |
+
{
|
| 1037 |
+
"epoch": 5.28169014084507,
|
| 1038 |
+
"grad_norm": 3.923652172088623,
|
| 1039 |
+
"learning_rate": 1.1263707571801567e-05,
|
| 1040 |
+
"loss": 0.1576,
|
| 1041 |
+
"step": 1410
|
| 1042 |
+
},
|
| 1043 |
+
{
|
| 1044 |
+
"epoch": 5.31924882629108,
|
| 1045 |
+
"grad_norm": 2.290224313735962,
|
| 1046 |
+
"learning_rate": 1.1107049608355092e-05,
|
| 1047 |
+
"loss": 0.16,
|
| 1048 |
+
"step": 1420
|
| 1049 |
+
},
|
| 1050 |
+
{
|
| 1051 |
+
"epoch": 5.356807511737089,
|
| 1052 |
+
"grad_norm": 2.332317590713501,
|
| 1053 |
+
"learning_rate": 1.0950391644908617e-05,
|
| 1054 |
+
"loss": 0.1505,
|
| 1055 |
+
"step": 1430
|
| 1056 |
+
},
|
| 1057 |
+
{
|
| 1058 |
+
"epoch": 5.394366197183099,
|
| 1059 |
+
"grad_norm": 3.474155902862549,
|
| 1060 |
+
"learning_rate": 1.0793733681462141e-05,
|
| 1061 |
+
"loss": 0.1828,
|
| 1062 |
+
"step": 1440
|
| 1063 |
+
},
|
| 1064 |
+
{
|
| 1065 |
+
"epoch": 5.431924882629108,
|
| 1066 |
+
"grad_norm": 2.5219180583953857,
|
| 1067 |
+
"learning_rate": 1.0637075718015665e-05,
|
| 1068 |
+
"loss": 0.1563,
|
| 1069 |
+
"step": 1450
|
| 1070 |
+
},
|
| 1071 |
+
{
|
| 1072 |
+
"epoch": 5.469483568075117,
|
| 1073 |
+
"grad_norm": 4.863851547241211,
|
| 1074 |
+
"learning_rate": 1.0480417754569191e-05,
|
| 1075 |
+
"loss": 0.1308,
|
| 1076 |
+
"step": 1460
|
| 1077 |
+
},
|
| 1078 |
+
{
|
| 1079 |
+
"epoch": 5.507042253521127,
|
| 1080 |
+
"grad_norm": 4.817688941955566,
|
| 1081 |
+
"learning_rate": 1.0323759791122715e-05,
|
| 1082 |
+
"loss": 0.1757,
|
| 1083 |
+
"step": 1470
|
| 1084 |
+
},
|
| 1085 |
+
{
|
| 1086 |
+
"epoch": 5.544600938967136,
|
| 1087 |
+
"grad_norm": 3.194732189178467,
|
| 1088 |
+
"learning_rate": 1.0167101827676241e-05,
|
| 1089 |
+
"loss": 0.1577,
|
| 1090 |
+
"step": 1480
|
| 1091 |
+
},
|
| 1092 |
+
{
|
| 1093 |
+
"epoch": 5.582159624413146,
|
| 1094 |
+
"grad_norm": 3.6605474948883057,
|
| 1095 |
+
"learning_rate": 1.0010443864229765e-05,
|
| 1096 |
+
"loss": 0.2044,
|
| 1097 |
+
"step": 1490
|
| 1098 |
+
},
|
| 1099 |
+
{
|
| 1100 |
+
"epoch": 5.619718309859155,
|
| 1101 |
+
"grad_norm": 2.427701473236084,
|
| 1102 |
+
"learning_rate": 9.853785900783291e-06,
|
| 1103 |
+
"loss": 0.1574,
|
| 1104 |
+
"step": 1500
|
| 1105 |
+
},
|
| 1106 |
+
{
|
| 1107 |
+
"epoch": 5.657276995305164,
|
| 1108 |
+
"grad_norm": 2.8025519847869873,
|
| 1109 |
+
"learning_rate": 9.697127937336815e-06,
|
| 1110 |
+
"loss": 0.188,
|
| 1111 |
+
"step": 1510
|
| 1112 |
+
},
|
| 1113 |
+
{
|
| 1114 |
+
"epoch": 5.694835680751174,
|
| 1115 |
+
"grad_norm": 2.042407989501953,
|
| 1116 |
+
"learning_rate": 9.54046997389034e-06,
|
| 1117 |
+
"loss": 0.1639,
|
| 1118 |
+
"step": 1520
|
| 1119 |
+
},
|
| 1120 |
+
{
|
| 1121 |
+
"epoch": 5.732394366197183,
|
| 1122 |
+
"grad_norm": 4.5383477210998535,
|
| 1123 |
+
"learning_rate": 9.383812010443865e-06,
|
| 1124 |
+
"loss": 0.1641,
|
| 1125 |
+
"step": 1530
|
| 1126 |
+
},
|
| 1127 |
+
{
|
| 1128 |
+
"epoch": 5.769953051643192,
|
| 1129 |
+
"grad_norm": 2.919588804244995,
|
| 1130 |
+
"learning_rate": 9.22715404699739e-06,
|
| 1131 |
+
"loss": 0.1374,
|
| 1132 |
+
"step": 1540
|
| 1133 |
+
},
|
| 1134 |
+
{
|
| 1135 |
+
"epoch": 5.807511737089202,
|
| 1136 |
+
"grad_norm": 2.4344029426574707,
|
| 1137 |
+
"learning_rate": 9.070496083550915e-06,
|
| 1138 |
+
"loss": 0.1711,
|
| 1139 |
+
"step": 1550
|
| 1140 |
+
},
|
| 1141 |
+
{
|
| 1142 |
+
"epoch": 5.845070422535211,
|
| 1143 |
+
"grad_norm": 1.5614906549453735,
|
| 1144 |
+
"learning_rate": 8.913838120104439e-06,
|
| 1145 |
+
"loss": 0.1624,
|
| 1146 |
+
"step": 1560
|
| 1147 |
+
},
|
| 1148 |
+
{
|
| 1149 |
+
"epoch": 5.882629107981221,
|
| 1150 |
+
"grad_norm": 3.0189967155456543,
|
| 1151 |
+
"learning_rate": 8.757180156657963e-06,
|
| 1152 |
+
"loss": 0.1691,
|
| 1153 |
+
"step": 1570
|
| 1154 |
+
},
|
| 1155 |
+
{
|
| 1156 |
+
"epoch": 5.92018779342723,
|
| 1157 |
+
"grad_norm": 2.44000506401062,
|
| 1158 |
+
"learning_rate": 8.600522193211488e-06,
|
| 1159 |
+
"loss": 0.1513,
|
| 1160 |
+
"step": 1580
|
| 1161 |
+
},
|
| 1162 |
+
{
|
| 1163 |
+
"epoch": 5.957746478873239,
|
| 1164 |
+
"grad_norm": 2.4327423572540283,
|
| 1165 |
+
"learning_rate": 8.443864229765013e-06,
|
| 1166 |
+
"loss": 0.1538,
|
| 1167 |
+
"step": 1590
|
| 1168 |
+
},
|
| 1169 |
+
{
|
| 1170 |
+
"epoch": 5.995305164319249,
|
| 1171 |
+
"grad_norm": 2.1192240715026855,
|
| 1172 |
+
"learning_rate": 8.287206266318538e-06,
|
| 1173 |
+
"loss": 0.1442,
|
| 1174 |
+
"step": 1600
|
| 1175 |
+
},
|
| 1176 |
+
{
|
| 1177 |
+
"epoch": 6.0,
|
| 1178 |
+
"eval_accuracy": 0.981318034716093,
|
| 1179 |
+
"eval_loss": 0.09270217269659042,
|
| 1180 |
+
"eval_runtime": 4.8524,
|
| 1181 |
+
"eval_samples_per_second": 1400.961,
|
| 1182 |
+
"eval_steps_per_second": 43.896,
|
| 1183 |
+
"step": 1602
|
| 1184 |
+
},
|
| 1185 |
+
{
|
| 1186 |
+
"epoch": 6.030046948356808,
|
| 1187 |
+
"grad_norm": 1.8678548336029053,
|
| 1188 |
+
"learning_rate": 8.130548302872062e-06,
|
| 1189 |
+
"loss": 0.1328,
|
| 1190 |
+
"step": 1610
|
| 1191 |
+
},
|
| 1192 |
+
{
|
| 1193 |
+
"epoch": 6.067605633802817,
|
| 1194 |
+
"grad_norm": 3.0712783336639404,
|
| 1195 |
+
"learning_rate": 7.973890339425586e-06,
|
| 1196 |
+
"loss": 0.1543,
|
| 1197 |
+
"step": 1620
|
| 1198 |
+
},
|
| 1199 |
+
{
|
| 1200 |
+
"epoch": 6.105164319248826,
|
| 1201 |
+
"grad_norm": 4.49588680267334,
|
| 1202 |
+
"learning_rate": 7.817232375979112e-06,
|
| 1203 |
+
"loss": 0.1452,
|
| 1204 |
+
"step": 1630
|
| 1205 |
+
},
|
| 1206 |
+
{
|
| 1207 |
+
"epoch": 6.142723004694836,
|
| 1208 |
+
"grad_norm": 3.9594759941101074,
|
| 1209 |
+
"learning_rate": 7.660574412532636e-06,
|
| 1210 |
+
"loss": 0.1513,
|
| 1211 |
+
"step": 1640
|
| 1212 |
+
},
|
| 1213 |
+
{
|
| 1214 |
+
"epoch": 6.180281690140845,
|
| 1215 |
+
"grad_norm": 2.528153657913208,
|
| 1216 |
+
"learning_rate": 7.503916449086162e-06,
|
| 1217 |
+
"loss": 0.1589,
|
| 1218 |
+
"step": 1650
|
| 1219 |
+
},
|
| 1220 |
+
{
|
| 1221 |
+
"epoch": 6.217840375586855,
|
| 1222 |
+
"grad_norm": 2.159458875656128,
|
| 1223 |
+
"learning_rate": 7.347258485639687e-06,
|
| 1224 |
+
"loss": 0.1443,
|
| 1225 |
+
"step": 1660
|
| 1226 |
+
},
|
| 1227 |
+
{
|
| 1228 |
+
"epoch": 6.255399061032864,
|
| 1229 |
+
"grad_norm": 2.098022222518921,
|
| 1230 |
+
"learning_rate": 7.190600522193212e-06,
|
| 1231 |
+
"loss": 0.1564,
|
| 1232 |
+
"step": 1670
|
| 1233 |
+
},
|
| 1234 |
+
{
|
| 1235 |
+
"epoch": 6.292957746478873,
|
| 1236 |
+
"grad_norm": 1.993698239326477,
|
| 1237 |
+
"learning_rate": 7.033942558746737e-06,
|
| 1238 |
+
"loss": 0.1401,
|
| 1239 |
+
"step": 1680
|
| 1240 |
+
},
|
| 1241 |
+
{
|
| 1242 |
+
"epoch": 6.330516431924883,
|
| 1243 |
+
"grad_norm": 2.2639145851135254,
|
| 1244 |
+
"learning_rate": 6.877284595300262e-06,
|
| 1245 |
+
"loss": 0.1452,
|
| 1246 |
+
"step": 1690
|
| 1247 |
+
},
|
| 1248 |
+
{
|
| 1249 |
+
"epoch": 6.368075117370892,
|
| 1250 |
+
"grad_norm": 2.5003936290740967,
|
| 1251 |
+
"learning_rate": 6.720626631853786e-06,
|
| 1252 |
+
"loss": 0.1439,
|
| 1253 |
+
"step": 1700
|
| 1254 |
+
},
|
| 1255 |
+
{
|
| 1256 |
+
"epoch": 6.405633802816902,
|
| 1257 |
+
"grad_norm": 2.0841052532196045,
|
| 1258 |
+
"learning_rate": 6.563968668407311e-06,
|
| 1259 |
+
"loss": 0.1438,
|
| 1260 |
+
"step": 1710
|
| 1261 |
+
},
|
| 1262 |
+
{
|
| 1263 |
+
"epoch": 6.443192488262911,
|
| 1264 |
+
"grad_norm": 3.550182819366455,
|
| 1265 |
+
"learning_rate": 6.4073107049608355e-06,
|
| 1266 |
+
"loss": 0.1433,
|
| 1267 |
+
"step": 1720
|
| 1268 |
+
},
|
| 1269 |
+
{
|
| 1270 |
+
"epoch": 6.48075117370892,
|
| 1271 |
+
"grad_norm": 1.4857251644134521,
|
| 1272 |
+
"learning_rate": 6.2506527415143605e-06,
|
| 1273 |
+
"loss": 0.1404,
|
| 1274 |
+
"step": 1730
|
| 1275 |
+
},
|
| 1276 |
+
{
|
| 1277 |
+
"epoch": 6.51830985915493,
|
| 1278 |
+
"grad_norm": 3.503309726715088,
|
| 1279 |
+
"learning_rate": 6.093994778067885e-06,
|
| 1280 |
+
"loss": 0.1493,
|
| 1281 |
+
"step": 1740
|
| 1282 |
+
},
|
| 1283 |
+
{
|
| 1284 |
+
"epoch": 6.555868544600939,
|
| 1285 |
+
"grad_norm": 3.59545636177063,
|
| 1286 |
+
"learning_rate": 5.93733681462141e-06,
|
| 1287 |
+
"loss": 0.1563,
|
| 1288 |
+
"step": 1750
|
| 1289 |
+
},
|
| 1290 |
+
{
|
| 1291 |
+
"epoch": 6.593427230046949,
|
| 1292 |
+
"grad_norm": 2.879582405090332,
|
| 1293 |
+
"learning_rate": 5.780678851174934e-06,
|
| 1294 |
+
"loss": 0.122,
|
| 1295 |
+
"step": 1760
|
| 1296 |
+
},
|
| 1297 |
+
{
|
| 1298 |
+
"epoch": 6.630985915492958,
|
| 1299 |
+
"grad_norm": 1.7240543365478516,
|
| 1300 |
+
"learning_rate": 5.624020887728459e-06,
|
| 1301 |
+
"loss": 0.1404,
|
| 1302 |
+
"step": 1770
|
| 1303 |
+
},
|
| 1304 |
+
{
|
| 1305 |
+
"epoch": 6.668544600938967,
|
| 1306 |
+
"grad_norm": 3.0438528060913086,
|
| 1307 |
+
"learning_rate": 5.467362924281984e-06,
|
| 1308 |
+
"loss": 0.1432,
|
| 1309 |
+
"step": 1780
|
| 1310 |
+
},
|
| 1311 |
+
{
|
| 1312 |
+
"epoch": 6.706103286384977,
|
| 1313 |
+
"grad_norm": 2.496366024017334,
|
| 1314 |
+
"learning_rate": 5.310704960835509e-06,
|
| 1315 |
+
"loss": 0.1277,
|
| 1316 |
+
"step": 1790
|
| 1317 |
+
},
|
| 1318 |
+
{
|
| 1319 |
+
"epoch": 6.743661971830986,
|
| 1320 |
+
"grad_norm": 1.7166277170181274,
|
| 1321 |
+
"learning_rate": 5.154046997389034e-06,
|
| 1322 |
+
"loss": 0.143,
|
| 1323 |
+
"step": 1800
|
| 1324 |
+
},
|
| 1325 |
+
{
|
| 1326 |
+
"epoch": 6.781220657276995,
|
| 1327 |
+
"grad_norm": 2.4547784328460693,
|
| 1328 |
+
"learning_rate": 4.997389033942559e-06,
|
| 1329 |
+
"loss": 0.1198,
|
| 1330 |
+
"step": 1810
|
| 1331 |
+
},
|
| 1332 |
+
{
|
| 1333 |
+
"epoch": 6.818779342723005,
|
| 1334 |
+
"grad_norm": 2.604220390319824,
|
| 1335 |
+
"learning_rate": 4.840731070496084e-06,
|
| 1336 |
+
"loss": 0.1705,
|
| 1337 |
+
"step": 1820
|
| 1338 |
+
},
|
| 1339 |
+
{
|
| 1340 |
+
"epoch": 6.856338028169014,
|
| 1341 |
+
"grad_norm": 2.7237601280212402,
|
| 1342 |
+
"learning_rate": 4.684073107049609e-06,
|
| 1343 |
+
"loss": 0.1506,
|
| 1344 |
+
"step": 1830
|
| 1345 |
+
},
|
| 1346 |
+
{
|
| 1347 |
+
"epoch": 6.893896713615024,
|
| 1348 |
+
"grad_norm": 2.638058662414551,
|
| 1349 |
+
"learning_rate": 4.527415143603134e-06,
|
| 1350 |
+
"loss": 0.154,
|
| 1351 |
+
"step": 1840
|
| 1352 |
+
},
|
| 1353 |
+
{
|
| 1354 |
+
"epoch": 6.931455399061033,
|
| 1355 |
+
"grad_norm": 3.8382205963134766,
|
| 1356 |
+
"learning_rate": 4.3707571801566586e-06,
|
| 1357 |
+
"loss": 0.1553,
|
| 1358 |
+
"step": 1850
|
| 1359 |
+
},
|
| 1360 |
+
{
|
| 1361 |
+
"epoch": 6.969014084507043,
|
| 1362 |
+
"grad_norm": 2.071164131164551,
|
| 1363 |
+
"learning_rate": 4.2140992167101835e-06,
|
| 1364 |
+
"loss": 0.1397,
|
| 1365 |
+
"step": 1860
|
| 1366 |
+
},
|
| 1367 |
+
{
|
| 1368 |
+
"epoch": 7.0,
|
| 1369 |
+
"eval_accuracy": 0.9811709326272433,
|
| 1370 |
+
"eval_loss": 0.08920056372880936,
|
| 1371 |
+
"eval_runtime": 4.9166,
|
| 1372 |
+
"eval_samples_per_second": 1382.662,
|
| 1373 |
+
"eval_steps_per_second": 43.323,
|
| 1374 |
+
"step": 1869
|
| 1375 |
+
},
|
| 1376 |
+
{
|
| 1377 |
+
"epoch": 7.003755868544601,
|
| 1378 |
+
"grad_norm": 2.5346381664276123,
|
| 1379 |
+
"learning_rate": 4.0574412532637075e-06,
|
| 1380 |
+
"loss": 0.1296,
|
| 1381 |
+
"step": 1870
|
| 1382 |
+
},
|
| 1383 |
+
{
|
| 1384 |
+
"epoch": 7.041314553990611,
|
| 1385 |
+
"grad_norm": 2.575307846069336,
|
| 1386 |
+
"learning_rate": 3.9007832898172325e-06,
|
| 1387 |
+
"loss": 0.1389,
|
| 1388 |
+
"step": 1880
|
| 1389 |
+
},
|
| 1390 |
+
{
|
| 1391 |
+
"epoch": 7.07887323943662,
|
| 1392 |
+
"grad_norm": 2.0408527851104736,
|
| 1393 |
+
"learning_rate": 3.7441253263707574e-06,
|
| 1394 |
+
"loss": 0.1521,
|
| 1395 |
+
"step": 1890
|
| 1396 |
+
},
|
| 1397 |
+
{
|
| 1398 |
+
"epoch": 7.1164319248826295,
|
| 1399 |
+
"grad_norm": 3.2742061614990234,
|
| 1400 |
+
"learning_rate": 3.5874673629242823e-06,
|
| 1401 |
+
"loss": 0.1342,
|
| 1402 |
+
"step": 1900
|
| 1403 |
+
},
|
| 1404 |
+
{
|
| 1405 |
+
"epoch": 7.153990610328639,
|
| 1406 |
+
"grad_norm": 1.4502960443496704,
|
| 1407 |
+
"learning_rate": 3.4308093994778068e-06,
|
| 1408 |
+
"loss": 0.1204,
|
| 1409 |
+
"step": 1910
|
| 1410 |
+
},
|
| 1411 |
+
{
|
| 1412 |
+
"epoch": 7.191549295774648,
|
| 1413 |
+
"grad_norm": 3.7600743770599365,
|
| 1414 |
+
"learning_rate": 3.2741514360313317e-06,
|
| 1415 |
+
"loss": 0.1431,
|
| 1416 |
+
"step": 1920
|
| 1417 |
+
},
|
| 1418 |
+
{
|
| 1419 |
+
"epoch": 7.229107981220658,
|
| 1420 |
+
"grad_norm": 2.7332417964935303,
|
| 1421 |
+
"learning_rate": 3.1174934725848566e-06,
|
| 1422 |
+
"loss": 0.1281,
|
| 1423 |
+
"step": 1930
|
| 1424 |
+
},
|
| 1425 |
+
{
|
| 1426 |
+
"epoch": 7.266666666666667,
|
| 1427 |
+
"grad_norm": 2.6618921756744385,
|
| 1428 |
+
"learning_rate": 2.960835509138381e-06,
|
| 1429 |
+
"loss": 0.141,
|
| 1430 |
+
"step": 1940
|
| 1431 |
+
},
|
| 1432 |
+
{
|
| 1433 |
+
"epoch": 7.304225352112676,
|
| 1434 |
+
"grad_norm": 3.625688314437866,
|
| 1435 |
+
"learning_rate": 2.804177545691906e-06,
|
| 1436 |
+
"loss": 0.1455,
|
| 1437 |
+
"step": 1950
|
| 1438 |
+
},
|
| 1439 |
+
{
|
| 1440 |
+
"epoch": 7.341784037558686,
|
| 1441 |
+
"grad_norm": 2.0667765140533447,
|
| 1442 |
+
"learning_rate": 2.647519582245431e-06,
|
| 1443 |
+
"loss": 0.1359,
|
| 1444 |
+
"step": 1960
|
| 1445 |
+
},
|
| 1446 |
+
{
|
| 1447 |
+
"epoch": 7.379342723004695,
|
| 1448 |
+
"grad_norm": 2.369652509689331,
|
| 1449 |
+
"learning_rate": 2.490861618798956e-06,
|
| 1450 |
+
"loss": 0.1295,
|
| 1451 |
+
"step": 1970
|
| 1452 |
+
},
|
| 1453 |
+
{
|
| 1454 |
+
"epoch": 7.416901408450705,
|
| 1455 |
+
"grad_norm": 3.836838722229004,
|
| 1456 |
+
"learning_rate": 2.3342036553524807e-06,
|
| 1457 |
+
"loss": 0.1489,
|
| 1458 |
+
"step": 1980
|
| 1459 |
+
},
|
| 1460 |
+
{
|
| 1461 |
+
"epoch": 7.454460093896714,
|
| 1462 |
+
"grad_norm": 3.3261311054229736,
|
| 1463 |
+
"learning_rate": 2.1775456919060052e-06,
|
| 1464 |
+
"loss": 0.1289,
|
| 1465 |
+
"step": 1990
|
| 1466 |
+
},
|
| 1467 |
+
{
|
| 1468 |
+
"epoch": 7.492018779342723,
|
| 1469 |
+
"grad_norm": 2.6514954566955566,
|
| 1470 |
+
"learning_rate": 2.0208877284595297e-06,
|
| 1471 |
+
"loss": 0.1185,
|
| 1472 |
+
"step": 2000
|
| 1473 |
+
},
|
| 1474 |
+
{
|
| 1475 |
+
"epoch": 7.529577464788733,
|
| 1476 |
+
"grad_norm": 2.1017005443573,
|
| 1477 |
+
"learning_rate": 1.8642297650130548e-06,
|
| 1478 |
+
"loss": 0.1472,
|
| 1479 |
+
"step": 2010
|
| 1480 |
+
},
|
| 1481 |
+
{
|
| 1482 |
+
"epoch": 7.567136150234742,
|
| 1483 |
+
"grad_norm": 2.5104258060455322,
|
| 1484 |
+
"learning_rate": 1.7075718015665795e-06,
|
| 1485 |
+
"loss": 0.1467,
|
| 1486 |
+
"step": 2020
|
| 1487 |
+
},
|
| 1488 |
+
{
|
| 1489 |
+
"epoch": 7.6046948356807516,
|
| 1490 |
+
"grad_norm": 1.7915935516357422,
|
| 1491 |
+
"learning_rate": 1.5509138381201045e-06,
|
| 1492 |
+
"loss": 0.1212,
|
| 1493 |
+
"step": 2030
|
| 1494 |
+
},
|
| 1495 |
+
{
|
| 1496 |
+
"epoch": 7.642253521126761,
|
| 1497 |
+
"grad_norm": 2.4937989711761475,
|
| 1498 |
+
"learning_rate": 1.3942558746736294e-06,
|
| 1499 |
+
"loss": 0.1395,
|
| 1500 |
+
"step": 2040
|
| 1501 |
+
},
|
| 1502 |
+
{
|
| 1503 |
+
"epoch": 7.67981220657277,
|
| 1504 |
+
"grad_norm": 2.758594274520874,
|
| 1505 |
+
"learning_rate": 1.237597911227154e-06,
|
| 1506 |
+
"loss": 0.1361,
|
| 1507 |
+
"step": 2050
|
| 1508 |
+
},
|
| 1509 |
+
{
|
| 1510 |
+
"epoch": 7.71737089201878,
|
| 1511 |
+
"grad_norm": 2.291672468185425,
|
| 1512 |
+
"learning_rate": 1.0809399477806788e-06,
|
| 1513 |
+
"loss": 0.1182,
|
| 1514 |
+
"step": 2060
|
| 1515 |
+
},
|
| 1516 |
+
{
|
| 1517 |
+
"epoch": 7.754929577464789,
|
| 1518 |
+
"grad_norm": 1.944736361503601,
|
| 1519 |
+
"learning_rate": 9.242819843342037e-07,
|
| 1520 |
+
"loss": 0.1307,
|
| 1521 |
+
"step": 2070
|
| 1522 |
+
},
|
| 1523 |
+
{
|
| 1524 |
+
"epoch": 7.792488262910798,
|
| 1525 |
+
"grad_norm": 1.448411226272583,
|
| 1526 |
+
"learning_rate": 7.676240208877285e-07,
|
| 1527 |
+
"loss": 0.1407,
|
| 1528 |
+
"step": 2080
|
| 1529 |
+
},
|
| 1530 |
+
{
|
| 1531 |
+
"epoch": 7.830046948356808,
|
| 1532 |
+
"grad_norm": 3.276000499725342,
|
| 1533 |
+
"learning_rate": 6.109660574412533e-07,
|
| 1534 |
+
"loss": 0.1361,
|
| 1535 |
+
"step": 2090
|
| 1536 |
+
},
|
| 1537 |
+
{
|
| 1538 |
+
"epoch": 7.867605633802817,
|
| 1539 |
+
"grad_norm": 3.627788543701172,
|
| 1540 |
+
"learning_rate": 4.5430809399477806e-07,
|
| 1541 |
+
"loss": 0.131,
|
| 1542 |
+
"step": 2100
|
| 1543 |
+
},
|
| 1544 |
+
{
|
| 1545 |
+
"epoch": 7.905164319248827,
|
| 1546 |
+
"grad_norm": 1.2533661127090454,
|
| 1547 |
+
"learning_rate": 2.9765013054830287e-07,
|
| 1548 |
+
"loss": 0.1245,
|
| 1549 |
+
"step": 2110
|
| 1550 |
+
},
|
| 1551 |
+
{
|
| 1552 |
+
"epoch": 7.942723004694836,
|
| 1553 |
+
"grad_norm": 1.472484827041626,
|
| 1554 |
+
"learning_rate": 1.409921671018277e-07,
|
| 1555 |
+
"loss": 0.1368,
|
| 1556 |
+
"step": 2120
|
| 1557 |
+
},
+{
+"epoch": 7.972769953051643,
+"eval_accuracy": 0.9811709326272433,
+"eval_loss": 0.08957477658987045,
+"eval_runtime": 5.418,
+"eval_samples_per_second": 1254.697,
+"eval_steps_per_second": 39.313,
+"step": 2128
+},
+{
+"epoch": 7.972769953051643,
+"step": 2128,
+"total_flos": 3.767900833756416e+18,
+"train_loss": 0.5178930132572812,
+"train_runtime": 756.2923,
+"train_samples_per_second": 540.468,
+"train_steps_per_second": 2.814
+}
+],
+"logging_steps": 10,
+"max_steps": 2128,
+"num_input_tokens_seen": 0,
+"num_train_epochs": 8,
+"save_steps": 500,
+"stateful_callbacks": {
+"TrainerControl": {
+"args": {
+"should_epoch_stop": false,
+"should_evaluate": false,
+"should_log": false,
+"should_save": true,
+"should_training_stop": true
+},
+"attributes": {}
+}
+},
+"total_flos": 3.767900833756416e+18,
+"train_batch_size": 48,
+"trial_name": null,
+"trial_params": null
+}
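The log above ends with the final per-epoch evaluation record and the aggregate training summary. A minimal sketch for pulling the evaluation records back out of the saved file, assuming the standard `log_history` key used by the Trainer state format:

import json

# Illustrative sketch (not part of the uploaded files): read the eval records
# from the trainer_state.json shown above.
with open("hpo-examples/audio-classification/ac/trainer_state.json") as f:
    state = json.load(f)

# Entries carrying eval_accuracy are the per-epoch evaluation logs.
for entry in state["log_history"]:
    if "eval_accuracy" in entry:
        print(entry["epoch"], entry["eval_accuracy"], entry["eval_loss"])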
hpo-examples/audio-classification/ac/training_args.bin
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:11e71a8faf8833e9bfb138217263fb0518314e3b8597902b752b2bc9dd143942
+size 5368
hpo-examples/audio-classification/requirements.txt
ADDED
@@ -0,0 +1,5 @@
+datasets>=1.14.0
+evaluate
+librosa
+torchaudio
+torch>=1.6
hpo-examples/audio-classification/run.sh
ADDED
@@ -0,0 +1,30 @@
+CUDA_VISIBLE_DEVICES=0 python run_audio_classification.py \
+--model_name_or_path facebook/wav2vec2-base \
+--dataset_name superb \
+--dataset_config_name ks \
+--trust_remote_code \
+--output_dir wav2vec2-base-ft-keyword-spotting \
+--overwrite_output_dir \
+--remove_unused_columns False \
+--do_train \
+--do_eval \
+--fp16 \
+--learning_rate 3e-5 \
+--max_length_seconds 1 \
+--attention_mask False \
+--warmup_ratio 0.1 \
+--num_train_epochs 8 \
+--per_device_train_batch_size 64 \
+--gradient_accumulation_steps 4 \
+--per_device_eval_batch_size 32 \
+--dataloader_num_workers 4 \
+--logging_strategy steps \
+--logging_steps 10 \
+--eval_strategy epoch \
+--save_strategy epoch \
+--load_best_model_at_end True \
+--metric_for_best_model accuracy \
+--save_total_limit 3 \
+--seed 0 \
+--push_to_hub \
+--apply-trp --trp-depths 1 --trp-p 0.1 --trp-lambdas 0.4 0.2 0.1
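The TRP flags at the end of run.sh correspond to the `apply_trp`, `trp_depths`, `trp_p`, and `trp_lambdas` fields of `ModelArguments` in run_audio_classification.py below, which forwards them to `trplib.apply_trp` right after the model is loaded. A minimal sketch of that call path, assuming the same wav2vec2 checkpoint and the flag values from run.sh:

# Illustrative sketch (not part of the uploaded files): how the run.sh TRP flags
# reach trplib.apply_trp inside run_audio_classification.py.
from transformers import AutoModelForAudioClassification

from trplib import apply_trp  # provided by hpo-examples/audio-classification/trplib.py

model = AutoModelForAudioClassification.from_pretrained("facebook/wav2vec2-base")
# Mirrors --apply-trp --trp-depths 1 --trp-p 0.1 --trp-lambdas 0.4 0.2 0.1
model = apply_trp(model, depths=1, p=0.1, lambdas=[0.4, 0.2, 0.1])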
hpo-examples/audio-classification/run_audio_classification.py
ADDED
@@ -0,0 +1,462 @@
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# coding=utf-8
|
| 3 |
+
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
|
| 17 |
+
import logging
|
| 18 |
+
import os
|
| 19 |
+
import sys
|
| 20 |
+
import warnings
|
| 21 |
+
from dataclasses import dataclass, field
|
| 22 |
+
from random import randint
|
| 23 |
+
from typing import Optional, List
|
| 24 |
+
|
| 25 |
+
import datasets
|
| 26 |
+
import evaluate
|
| 27 |
+
import numpy as np
|
| 28 |
+
from datasets import DatasetDict, load_dataset
|
| 29 |
+
|
| 30 |
+
import transformers
|
| 31 |
+
from transformers import (
|
| 32 |
+
AutoConfig,
|
| 33 |
+
AutoFeatureExtractor,
|
| 34 |
+
AutoModelForAudioClassification,
|
| 35 |
+
HfArgumentParser,
|
| 36 |
+
Trainer,
|
| 37 |
+
TrainingArguments,
|
| 38 |
+
set_seed,
|
| 39 |
+
)
|
| 40 |
+
from transformers.trainer_utils import get_last_checkpoint
|
| 41 |
+
from transformers.utils import check_min_version, send_example_telemetry
|
| 42 |
+
from transformers.utils.versions import require_version
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
from trplib import apply_trp
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
logger = logging.getLogger(__name__)
|
| 49 |
+
|
| 50 |
+
# # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
| 51 |
+
# check_min_version("4.50.0.dev0")
|
| 52 |
+
|
| 53 |
+
require_version("datasets>=1.14.0", "To fix: pip install -r examples/pytorch/audio-classification/requirements.txt")
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def random_subsample(wav: np.ndarray, max_length: float, sample_rate: int = 16000):
|
| 57 |
+
"""Randomly sample chunks of `max_length` seconds from the input audio"""
|
| 58 |
+
sample_length = int(round(sample_rate * max_length))
|
| 59 |
+
if len(wav) <= sample_length:
|
| 60 |
+
return wav
|
| 61 |
+
random_offset = randint(0, len(wav) - sample_length - 1)
|
| 62 |
+
return wav[random_offset : random_offset + sample_length]
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
@dataclass
|
| 66 |
+
class DataTrainingArguments:
|
| 67 |
+
"""
|
| 68 |
+
Arguments pertaining to what data we are going to input our model for training and eval.
|
| 69 |
+
Using `HfArgumentParser` we can turn this class
|
| 70 |
+
into argparse arguments to be able to specify them on
|
| 71 |
+
the command line.
|
| 72 |
+
"""
|
| 73 |
+
|
| 74 |
+
dataset_name: Optional[str] = field(default=None, metadata={"help": "Name of a dataset from the datasets package"})
|
| 75 |
+
dataset_config_name: Optional[str] = field(
|
| 76 |
+
default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
|
| 77 |
+
)
|
| 78 |
+
train_file: Optional[str] = field(
|
| 79 |
+
default=None, metadata={"help": "A file containing the training audio paths and labels."}
|
| 80 |
+
)
|
| 81 |
+
eval_file: Optional[str] = field(
|
| 82 |
+
default=None, metadata={"help": "A file containing the validation audio paths and labels."}
|
| 83 |
+
)
|
| 84 |
+
train_split_name: str = field(
|
| 85 |
+
default="train",
|
| 86 |
+
metadata={
|
| 87 |
+
"help": "The name of the training data set split to use (via the datasets library). Defaults to 'train'"
|
| 88 |
+
},
|
| 89 |
+
)
|
| 90 |
+
eval_split_name: str = field(
|
| 91 |
+
default="validation",
|
| 92 |
+
metadata={
|
| 93 |
+
"help": (
|
| 94 |
+
"The name of the training data set split to use (via the datasets library). Defaults to 'validation'"
|
| 95 |
+
)
|
| 96 |
+
},
|
| 97 |
+
)
|
| 98 |
+
audio_column_name: str = field(
|
| 99 |
+
default="audio",
|
| 100 |
+
metadata={"help": "The name of the dataset column containing the audio data. Defaults to 'audio'"},
|
| 101 |
+
)
|
| 102 |
+
label_column_name: str = field(
|
| 103 |
+
default="label", metadata={"help": "The name of the dataset column containing the labels. Defaults to 'label'"}
|
| 104 |
+
)
|
| 105 |
+
max_train_samples: Optional[int] = field(
|
| 106 |
+
default=None,
|
| 107 |
+
metadata={
|
| 108 |
+
"help": (
|
| 109 |
+
"For debugging purposes or quicker training, truncate the number of training examples to this "
|
| 110 |
+
"value if set."
|
| 111 |
+
)
|
| 112 |
+
},
|
| 113 |
+
)
|
| 114 |
+
max_eval_samples: Optional[int] = field(
|
| 115 |
+
default=None,
|
| 116 |
+
metadata={
|
| 117 |
+
"help": (
|
| 118 |
+
"For debugging purposes or quicker training, truncate the number of evaluation examples to this "
|
| 119 |
+
"value if set."
|
| 120 |
+
)
|
| 121 |
+
},
|
| 122 |
+
)
|
| 123 |
+
max_length_seconds: float = field(
|
| 124 |
+
default=20,
|
| 125 |
+
metadata={"help": "Audio clips will be randomly cut to this length during training if the value is set."},
|
| 126 |
+
)
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
@dataclass
|
| 130 |
+
class ModelArguments:
|
| 131 |
+
"""
|
| 132 |
+
Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
|
| 133 |
+
"""
|
| 134 |
+
|
| 135 |
+
model_name_or_path: str = field(
|
| 136 |
+
default="facebook/wav2vec2-base",
|
| 137 |
+
metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"},
|
| 138 |
+
)
|
| 139 |
+
config_name: Optional[str] = field(
|
| 140 |
+
default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
|
| 141 |
+
)
|
| 142 |
+
cache_dir: Optional[str] = field(
|
| 143 |
+
default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from the Hub"}
|
| 144 |
+
)
|
| 145 |
+
model_revision: str = field(
|
| 146 |
+
default="main",
|
| 147 |
+
metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
|
| 148 |
+
)
|
| 149 |
+
feature_extractor_name: Optional[str] = field(
|
| 150 |
+
default=None, metadata={"help": "Name or path of preprocessor config."}
|
| 151 |
+
)
|
| 152 |
+
freeze_feature_encoder: bool = field(
|
| 153 |
+
default=True, metadata={"help": "Whether to freeze the feature encoder layers of the model."}
|
| 154 |
+
)
|
| 155 |
+
attention_mask: bool = field(
|
| 156 |
+
default=True, metadata={"help": "Whether to generate an attention mask in the feature extractor."}
|
| 157 |
+
)
|
| 158 |
+
token: str = field(
|
| 159 |
+
default=None,
|
| 160 |
+
metadata={
|
| 161 |
+
"help": (
|
| 162 |
+
"The token to use as HTTP bearer authorization for remote files. If not specified, will use the token "
|
| 163 |
+
"generated when running `huggingface-cli login` (stored in `~/.huggingface`)."
|
| 164 |
+
)
|
| 165 |
+
},
|
| 166 |
+
)
|
| 167 |
+
trust_remote_code: bool = field(
|
| 168 |
+
default=False,
|
| 169 |
+
metadata={
|
| 170 |
+
"help": (
|
| 171 |
+
"Whether to trust the execution of code from datasets/models defined on the Hub."
|
| 172 |
+
" This option should only be set to `True` for repositories you trust and in which you have read the"
|
| 173 |
+
" code, as it will execute code present on the Hub on your local machine."
|
| 174 |
+
)
|
| 175 |
+
},
|
| 176 |
+
)
|
| 177 |
+
freeze_feature_extractor: Optional[bool] = field(
|
| 178 |
+
default=None, metadata={"help": "Whether to freeze the feature extractor layers of the model."}
|
| 179 |
+
)
|
| 180 |
+
ignore_mismatched_sizes: bool = field(
|
| 181 |
+
default=False,
|
| 182 |
+
metadata={"help": "Will enable to load a pretrained model whose head dimensions are different."},
|
| 183 |
+
)
|
| 184 |
+
|
| 185 |
+
apply_trp: Optional[bool] = field(
|
| 186 |
+
default=False,
|
| 187 |
+
metadata={"help": "Whether to apply trp or not."},
|
| 188 |
+
)
|
| 189 |
+
trp_depths: Optional[int] = field(
|
| 190 |
+
default=1,
|
| 191 |
+
metadata={
|
| 192 |
+
"help": "TRP depth value."
|
| 193 |
+
},
|
| 194 |
+
)
|
| 195 |
+
trp_p: Optional[float] = field(
|
| 196 |
+
default=0.1,
|
| 197 |
+
metadata={
|
| 198 |
+
"help": "TRP p value."
|
| 199 |
+
},
|
| 200 |
+
)
|
| 201 |
+
trp_lambdas: Optional[List[float]] = field(
|
| 202 |
+
default_factory=lambda: [0.4, 0.2, 0.1],
|
| 203 |
+
metadata={
|
| 204 |
+
"help": "TRP lambda values (list of floats)."
|
| 205 |
+
},
|
| 206 |
+
)
|
| 207 |
+
|
| 208 |
+
def __post_init__(self):
|
| 209 |
+
if not self.freeze_feature_extractor and self.freeze_feature_encoder:
|
| 210 |
+
warnings.warn(
|
| 211 |
+
"The argument `--freeze_feature_extractor` is deprecated and "
|
| 212 |
+
"will be removed in a future version. Use `--freeze_feature_encoder` "
|
| 213 |
+
"instead. Setting `freeze_feature_encoder==True`.",
|
| 214 |
+
FutureWarning,
|
| 215 |
+
)
|
| 216 |
+
if self.freeze_feature_extractor and not self.freeze_feature_encoder:
|
| 217 |
+
raise ValueError(
|
| 218 |
+
"The argument `--freeze_feature_extractor` is deprecated and "
|
| 219 |
+
"should not be used in combination with `--freeze_feature_encoder`. "
|
| 220 |
+
"Only make use of `--freeze_feature_encoder`."
|
| 221 |
+
)
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
def main():
|
| 225 |
+
# See all possible arguments in src/transformers/training_args.py
|
| 226 |
+
# or by passing the --help flag to this script.
|
| 227 |
+
# We now keep distinct sets of args, for a cleaner separation of concerns.
|
| 228 |
+
|
| 229 |
+
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
|
| 230 |
+
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
|
| 231 |
+
# If we pass only one argument to the script and it's the path to a json file,
|
| 232 |
+
# let's parse it to get our arguments.
|
| 233 |
+
model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
|
| 234 |
+
else:
|
| 235 |
+
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
|
| 236 |
+
|
| 237 |
+
# Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
|
| 238 |
+
# information sent is the one passed as arguments along with your Python/PyTorch versions.
|
| 239 |
+
send_example_telemetry("run_audio_classification", model_args, data_args)
|
| 240 |
+
|
| 241 |
+
# Setup logging
|
| 242 |
+
logging.basicConfig(
|
| 243 |
+
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
| 244 |
+
datefmt="%m/%d/%Y %H:%M:%S",
|
| 245 |
+
handlers=[logging.StreamHandler(sys.stdout)],
|
| 246 |
+
)
|
| 247 |
+
|
| 248 |
+
if training_args.should_log:
|
| 249 |
+
# The default of training_args.log_level is passive, so we set log level at info here to have that default.
|
| 250 |
+
transformers.utils.logging.set_verbosity_info()
|
| 251 |
+
|
| 252 |
+
log_level = training_args.get_process_log_level()
|
| 253 |
+
logger.setLevel(log_level)
|
| 254 |
+
transformers.utils.logging.set_verbosity(log_level)
|
| 255 |
+
transformers.utils.logging.enable_default_handler()
|
| 256 |
+
transformers.utils.logging.enable_explicit_format()
|
| 257 |
+
|
| 258 |
+
# Log on each process the small summary:
|
| 259 |
+
logger.warning(
|
| 260 |
+
f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
|
| 261 |
+
+ f"distributed training: {training_args.parallel_mode.value == 'distributed'}, 16-bits training: {training_args.fp16}"
|
| 262 |
+
)
|
| 263 |
+
logger.info(f"Training/evaluation parameters {training_args}")
|
| 264 |
+
|
| 265 |
+
# Set seed before initializing model.
|
| 266 |
+
set_seed(training_args.seed)
|
| 267 |
+
|
| 268 |
+
# Detecting last checkpoint.
|
| 269 |
+
last_checkpoint = None
|
| 270 |
+
if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
|
| 271 |
+
last_checkpoint = get_last_checkpoint(training_args.output_dir)
|
| 272 |
+
if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
|
| 273 |
+
raise ValueError(
|
| 274 |
+
f"Output directory ({training_args.output_dir}) already exists and is not empty. "
|
| 275 |
+
"Use --overwrite_output_dir to train from scratch."
|
| 276 |
+
)
|
| 277 |
+
elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
|
| 278 |
+
logger.info(
|
| 279 |
+
f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
|
| 280 |
+
"the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
|
| 281 |
+
)
|
| 282 |
+
|
| 283 |
+
# Initialize our dataset and prepare it for the audio classification task.
|
| 284 |
+
raw_datasets = DatasetDict()
|
| 285 |
+
raw_datasets["train"] = load_dataset(
|
| 286 |
+
data_args.dataset_name,
|
| 287 |
+
data_args.dataset_config_name,
|
| 288 |
+
split=data_args.train_split_name,
|
| 289 |
+
token=model_args.token,
|
| 290 |
+
trust_remote_code=model_args.trust_remote_code,
|
| 291 |
+
)
|
| 292 |
+
raw_datasets["eval"] = load_dataset(
|
| 293 |
+
data_args.dataset_name,
|
| 294 |
+
data_args.dataset_config_name,
|
| 295 |
+
split=data_args.eval_split_name,
|
| 296 |
+
token=model_args.token,
|
| 297 |
+
trust_remote_code=model_args.trust_remote_code,
|
| 298 |
+
)
|
| 299 |
+
|
| 300 |
+
if data_args.audio_column_name not in raw_datasets["train"].column_names:
|
| 301 |
+
raise ValueError(
|
| 302 |
+
f"--audio_column_name {data_args.audio_column_name} not found in dataset '{data_args.dataset_name}'. "
|
| 303 |
+
"Make sure to set `--audio_column_name` to the correct audio column - one of "
|
| 304 |
+
f"{', '.join(raw_datasets['train'].column_names)}."
|
| 305 |
+
)
|
| 306 |
+
|
| 307 |
+
if data_args.label_column_name not in raw_datasets["train"].column_names:
|
| 308 |
+
raise ValueError(
|
| 309 |
+
f"--label_column_name {data_args.label_column_name} not found in dataset '{data_args.dataset_name}'. "
|
| 310 |
+
"Make sure to set `--label_column_name` to the correct text column - one of "
|
| 311 |
+
f"{', '.join(raw_datasets['train'].column_names)}."
|
| 312 |
+
)
|
| 313 |
+
|
| 314 |
+
# Setting `return_attention_mask=True` is the way to get a correctly masked mean-pooling over
|
| 315 |
+
# transformer outputs in the classifier, but it doesn't always lead to better accuracy
|
| 316 |
+
feature_extractor = AutoFeatureExtractor.from_pretrained(
|
| 317 |
+
model_args.feature_extractor_name or model_args.model_name_or_path,
|
| 318 |
+
return_attention_mask=model_args.attention_mask,
|
| 319 |
+
cache_dir=model_args.cache_dir,
|
| 320 |
+
revision=model_args.model_revision,
|
| 321 |
+
token=model_args.token,
|
| 322 |
+
trust_remote_code=model_args.trust_remote_code,
|
| 323 |
+
)
|
| 324 |
+
|
| 325 |
+
# `datasets` takes care of automatically loading and resampling the audio,
|
| 326 |
+
# so we just need to set the correct target sampling rate.
|
| 327 |
+
raw_datasets = raw_datasets.cast_column(
|
| 328 |
+
data_args.audio_column_name, datasets.features.Audio(sampling_rate=feature_extractor.sampling_rate)
|
| 329 |
+
)
|
| 330 |
+
|
| 331 |
+
model_input_name = feature_extractor.model_input_names[0]
|
| 332 |
+
|
| 333 |
+
def train_transforms(batch):
|
| 334 |
+
"""Apply train_transforms across a batch."""
|
| 335 |
+
subsampled_wavs = []
|
| 336 |
+
for audio in batch[data_args.audio_column_name]:
|
| 337 |
+
wav = random_subsample(
|
| 338 |
+
audio["array"], max_length=data_args.max_length_seconds, sample_rate=feature_extractor.sampling_rate
|
| 339 |
+
)
|
| 340 |
+
subsampled_wavs.append(wav)
|
| 341 |
+
inputs = feature_extractor(subsampled_wavs, sampling_rate=feature_extractor.sampling_rate)
|
| 342 |
+
output_batch = {model_input_name: inputs.get(model_input_name)}
|
| 343 |
+
output_batch["labels"] = list(batch[data_args.label_column_name])
|
| 344 |
+
|
| 345 |
+
return output_batch
|
| 346 |
+
|
| 347 |
+
def val_transforms(batch):
|
| 348 |
+
"""Apply val_transforms across a batch."""
|
| 349 |
+
wavs = [audio["array"] for audio in batch[data_args.audio_column_name]]
|
| 350 |
+
inputs = feature_extractor(wavs, sampling_rate=feature_extractor.sampling_rate)
|
| 351 |
+
output_batch = {model_input_name: inputs.get(model_input_name)}
|
| 352 |
+
output_batch["labels"] = list(batch[data_args.label_column_name])
|
| 353 |
+
|
| 354 |
+
return output_batch
|
| 355 |
+
|
| 356 |
+
# Prepare label mappings.
|
| 357 |
+
# We'll include these in the model's config to get human readable labels in the Inference API.
|
| 358 |
+
labels = raw_datasets["train"].features[data_args.label_column_name].names
|
| 359 |
+
label2id, id2label = {}, {}
|
| 360 |
+
for i, label in enumerate(labels):
|
| 361 |
+
label2id[label] = str(i)
|
| 362 |
+
id2label[str(i)] = label
|
| 363 |
+
|
| 364 |
+
# Load the accuracy metric from the datasets package
|
| 365 |
+
metric = evaluate.load("accuracy", cache_dir=model_args.cache_dir)
|
| 366 |
+
|
| 367 |
+
# Define our compute_metrics function. It takes an `EvalPrediction` object (a namedtuple with
|
| 368 |
+
# `predictions` and `label_ids` fields) and has to return a dictionary string to float.
|
| 369 |
+
def compute_metrics(eval_pred):
|
| 370 |
+
"""Computes accuracy on a batch of predictions"""
|
| 371 |
+
predictions = np.argmax(eval_pred.predictions, axis=1)
|
| 372 |
+
return metric.compute(predictions=predictions, references=eval_pred.label_ids)
|
| 373 |
+
|
| 374 |
+
config = AutoConfig.from_pretrained(
|
| 375 |
+
model_args.config_name or model_args.model_name_or_path,
|
| 376 |
+
num_labels=len(labels),
|
| 377 |
+
label2id=label2id,
|
| 378 |
+
id2label=id2label,
|
| 379 |
+
finetuning_task="audio-classification",
|
| 380 |
+
cache_dir=model_args.cache_dir,
|
| 381 |
+
revision=model_args.model_revision,
|
| 382 |
+
token=model_args.token,
|
| 383 |
+
trust_remote_code=model_args.trust_remote_code,
|
| 384 |
+
)
|
| 385 |
+
model = AutoModelForAudioClassification.from_pretrained(
|
| 386 |
+
model_args.model_name_or_path,
|
| 387 |
+
from_tf=bool(".ckpt" in model_args.model_name_or_path),
|
| 388 |
+
config=config,
|
| 389 |
+
cache_dir=model_args.cache_dir,
|
| 390 |
+
revision=model_args.model_revision,
|
| 391 |
+
token=model_args.token,
|
| 392 |
+
trust_remote_code=model_args.trust_remote_code,
|
| 393 |
+
ignore_mismatched_sizes=model_args.ignore_mismatched_sizes,
|
| 394 |
+
)
|
| 395 |
+
|
| 396 |
+
# freeze the convolutional waveform encoder
|
| 397 |
+
if model_args.freeze_feature_encoder:
|
| 398 |
+
model.freeze_feature_encoder()
|
| 399 |
+
|
| 400 |
+
if model_args.apply_trp:
|
| 401 |
+
model = apply_trp(model, depths=model_args.trp_depths, p=model_args.trp_p, lambdas=model_args.trp_lambdas)
|
| 402 |
+
|
| 403 |
+
if training_args.do_train:
|
| 404 |
+
if data_args.max_train_samples is not None:
|
| 405 |
+
raw_datasets["train"] = (
|
| 406 |
+
raw_datasets["train"].shuffle(seed=training_args.seed).select(range(data_args.max_train_samples))
|
| 407 |
+
)
|
| 408 |
+
# Set the training transforms
|
| 409 |
+
raw_datasets["train"].set_transform(train_transforms, output_all_columns=False)
|
| 410 |
+
|
| 411 |
+
if training_args.do_eval:
|
| 412 |
+
if data_args.max_eval_samples is not None:
|
| 413 |
+
raw_datasets["eval"] = (
|
| 414 |
+
raw_datasets["eval"].shuffle(seed=training_args.seed).select(range(data_args.max_eval_samples))
|
| 415 |
+
)
|
| 416 |
+
# Set the validation transforms
|
| 417 |
+
raw_datasets["eval"].set_transform(val_transforms, output_all_columns=False)
|
| 418 |
+
|
| 419 |
+
# Initialize our trainer
|
| 420 |
+
trainer = Trainer(
|
| 421 |
+
model=model,
|
| 422 |
+
args=training_args,
|
| 423 |
+
train_dataset=raw_datasets["train"] if training_args.do_train else None,
|
| 424 |
+
eval_dataset=raw_datasets["eval"] if training_args.do_eval else None,
|
| 425 |
+
compute_metrics=compute_metrics,
|
| 426 |
+
processing_class=feature_extractor,
|
| 427 |
+
)
|
| 428 |
+
|
| 429 |
+
# Training
|
| 430 |
+
if training_args.do_train:
|
| 431 |
+
checkpoint = None
|
| 432 |
+
if training_args.resume_from_checkpoint is not None:
|
| 433 |
+
checkpoint = training_args.resume_from_checkpoint
|
| 434 |
+
elif last_checkpoint is not None:
|
| 435 |
+
checkpoint = last_checkpoint
|
| 436 |
+
train_result = trainer.train(resume_from_checkpoint=checkpoint)
|
| 437 |
+
trainer.save_model()
|
| 438 |
+
trainer.log_metrics("train", train_result.metrics)
|
| 439 |
+
trainer.save_metrics("train", train_result.metrics)
|
| 440 |
+
trainer.save_state()
|
| 441 |
+
|
| 442 |
+
# Evaluation
|
| 443 |
+
if training_args.do_eval:
|
| 444 |
+
metrics = trainer.evaluate()
|
| 445 |
+
trainer.log_metrics("eval", metrics)
|
| 446 |
+
trainer.save_metrics("eval", metrics)
|
| 447 |
+
|
| 448 |
+
# Write model card and (optionally) push to hub
|
| 449 |
+
kwargs = {
|
| 450 |
+
"finetuned_from": model_args.model_name_or_path,
|
| 451 |
+
"tasks": "audio-classification",
|
| 452 |
+
"dataset": data_args.dataset_name,
|
| 453 |
+
"tags": ["audio-classification"],
|
| 454 |
+
}
|
| 455 |
+
if training_args.push_to_hub:
|
| 456 |
+
trainer.push_to_hub(**kwargs)
|
| 457 |
+
else:
|
| 458 |
+
trainer.create_model_card(**kwargs)
|
| 459 |
+
|
| 460 |
+
|
| 461 |
+
if __name__ == "__main__":
|
| 462 |
+
main()
|
hpo-examples/audio-classification/trplib.py
ADDED
@@ -0,0 +1,1181 @@
| 1 |
+
import torch
|
| 2 |
+
from torch import nn, Tensor
|
| 3 |
+
from torch.nn import functional as F
|
| 4 |
+
|
| 5 |
+
from torchvision.models.mobilenetv2 import MobileNetV2
|
| 6 |
+
from torchvision.models.resnet import ResNet
|
| 7 |
+
from torchvision.models.efficientnet import EfficientNet
|
| 8 |
+
from torchvision.models.vision_transformer import VisionTransformer
|
| 9 |
+
from torchvision.models.segmentation.fcn import FCN
|
| 10 |
+
from torchvision.models.segmentation.deeplabv3 import DeepLabV3
|
| 11 |
+
|
| 12 |
+
import transformers
|
| 13 |
+
from transformers.modeling_outputs import SequenceClassifierOutput, QuestionAnsweringModelOutput, CausalLMOutput, Seq2SeqLMOutput
|
| 14 |
+
|
| 15 |
+
from typing import Optional, Tuple, List, Union, Callable
|
| 16 |
+
from collections import OrderedDict
|
| 17 |
+
import types
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def trp_criterion(trp_blocks: nn.ModuleList, shared_head: Callable, criterion: Callable, lambdas: List[float], hidden_states: Tensor, logits: Tensor, targets: Tensor, loss_normalization=False):
|
| 21 |
+
loss, mask = criterion(logits, targets)
|
| 22 |
+
if loss_normalization:
|
| 23 |
+
coeff = loss.detach()
|
| 24 |
+
|
| 25 |
+
embeds = [hidden_states]
|
| 26 |
+
predictions = []
|
| 27 |
+
for k, c in enumerate(lambdas):
|
| 28 |
+
embeds.append(trp_blocks[k](embeds[-1]))
|
| 29 |
+
predictions.append(shared_head(embeds[-1]))
|
| 30 |
+
replica_loss, mask = criterion(predictions[-1], targets, mask)
|
| 31 |
+
loss += c * replica_loss
|
| 32 |
+
|
| 33 |
+
if loss_normalization:
|
| 34 |
+
with torch.no_grad():
|
| 35 |
+
coeff = torch.exp(coeff) / torch.exp(loss.detach())
|
| 36 |
+
loss = coeff * loss
|
| 37 |
+
|
| 38 |
+
return loss
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
class TPBlock(nn.Module):
|
| 42 |
+
def __init__(self, depths: int, in_features: int, p: float, dim=-1):
|
| 43 |
+
super(TPBlock, self).__init__()
|
| 44 |
+
|
| 45 |
+
self.dropout = nn.Dropout(p)
|
| 46 |
+
|
| 47 |
+
self.cdim = dim
|
| 48 |
+
|
| 49 |
+
blocks = []
|
| 50 |
+
for _ in range(depths):
|
| 51 |
+
blocks.append(nn.Linear(in_features, in_features))
|
| 52 |
+
nn.init.constant_(blocks[-1].weight, 0.0)
|
| 53 |
+
nn.init.constant_(blocks[-1].bias, 0.0)
|
| 54 |
+
blocks.append(nn.ReLU())
|
| 55 |
+
self.blocks = nn.Sequential(*blocks)
|
| 56 |
+
|
| 57 |
+
def forward(self, x):
|
| 58 |
+
x = self.dropout(x)
|
| 59 |
+
if self.cdim == -1:
|
| 60 |
+
x = x + self.blocks(x)
|
| 61 |
+
else:
|
| 62 |
+
x = x + torch.movedim(self.blocks(torch.movedim(x, self.cdim, -1)), -1, self.cdim)
|
| 63 |
+
return x
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
class Config:
|
| 67 |
+
@staticmethod
|
| 68 |
+
def gen_criterion(*args, **kwargs):
|
| 69 |
+
def func(input, target, mask=None):
|
| 70 |
+
"""
|
| 71 |
+
Args:
|
| 72 |
+
input (Tensor): Input tensor.
|
| 73 |
+
target (Tensor): Target labels.
|
| 74 |
+
|
| 75 |
+
Returns:
|
| 76 |
+
loss (Tensor): Scalar tensor representing the loss.
|
| 77 |
+
mask (Tensor): Boolean mask tensor with the same shape of target.
|
| 78 |
+
"""
|
| 79 |
+
pass
|
| 80 |
+
return func
|
| 81 |
+
|
| 82 |
+
@staticmethod
|
| 83 |
+
def gen_shared_head(*args, **kwargs):
|
| 84 |
+
def func(hidden_states):
|
| 85 |
+
"""
|
| 86 |
+
Args:
|
| 87 |
+
hidden_states (Tensor): Hidden States tensor.
|
| 88 |
+
|
| 89 |
+
Returns:
|
| 90 |
+
logits (Tensor): Logits tensor.
|
| 91 |
+
"""
|
| 92 |
+
pass
|
| 93 |
+
return func
|
| 94 |
+
|
| 95 |
+
@staticmethod
|
| 96 |
+
def forward(*args, **kwargs):
|
| 97 |
+
pass
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
# Wav2Vec2 for Audio Classification
|
| 101 |
+
class Wav2Vec2ForSequenceClassificationConfig(Config):
|
| 102 |
+
_HIDDEN_STATES_START_POSITION = 2
|
| 103 |
+
|
| 104 |
+
@staticmethod
|
| 105 |
+
def gen_criterion():
|
| 106 |
+
def func(input, target, mask=None):
|
| 107 |
+
"""
|
| 108 |
+
Args:
|
| 109 |
+
input (Tensor): Input tensor of shape [B, C].
|
| 110 |
+
target (Tensor): Target labels of shape [B].
|
| 111 |
+
|
| 112 |
+
Returns:
|
| 113 |
+
loss (Tensor): Scalar tensor representing the loss.
|
| 114 |
+
mask (Tensor): Boolean mask tensor of shape [B].
|
| 115 |
+
"""
|
| 116 |
+
if mask is None:
|
| 117 |
+
mask = torch.ones_like(target, dtype=torch.float32, device=target.device)
|
| 118 |
+
|
| 119 |
+
unmasked_loss = F.cross_entropy(input, target, reduction="none")
|
| 120 |
+
loss = torch.sum(mask * unmasked_loss) / (torch.sum(mask) + 1e-6)
|
| 121 |
+
|
| 122 |
+
with torch.no_grad():
|
| 123 |
+
mask = mask * torch.eq(torch.argmax(input, dim=1), target).to(input.dtype)
|
| 124 |
+
|
| 125 |
+
return loss, mask
|
| 126 |
+
return func
|
| 127 |
+
|
| 128 |
+
@staticmethod
|
| 129 |
+
def gen_shared_head(self, attention_mask):
|
| 130 |
+
def func(hidden_states):
|
| 131 |
+
"""
|
| 132 |
+
Args:
|
| 133 |
+
hidden_states (Tensor): Hidden States of shape [B, L, hidden_units].
|
| 134 |
+
|
| 135 |
+
Returns:
|
| 136 |
+
logits (Tensor): Logits tensor of shape [B, C].
|
| 137 |
+
"""
|
| 138 |
+
_hidden_states = self.projector(hidden_states)
|
| 139 |
+
if attention_mask is None:
|
| 140 |
+
pooled_output = _hidden_states.mean(dim=1)
|
| 141 |
+
else:
|
| 142 |
+
padding_mask = self._get_feature_vector_attention_mask(_hidden_states.shape[1], attention_mask)
|
| 143 |
+
expand_padding_mask = padding_mask.unsqueeze(-1).repeat(1, 1, _hidden_states.shape[2])
|
| 144 |
+
_hidden_states[~expand_padding_mask] = 0.0
|
| 145 |
+
pooled_output = _hidden_states.sum(dim=1) / padding_mask.sum(dim=1).view(-1, 1)
|
| 146 |
+
|
| 147 |
+
logits = self.classifier(pooled_output)
|
| 148 |
+
return logits
|
| 149 |
+
return func
|
| 150 |
+
|
| 151 |
+
@staticmethod
|
| 152 |
+
def gen_forward(lambdas, loss_normalization=False):
|
| 153 |
+
def func(
|
| 154 |
+
self,
|
| 155 |
+
input_values: Optional[torch.Tensor],
|
| 156 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 157 |
+
output_attentions: Optional[bool] = None,
|
| 158 |
+
output_hidden_states: Optional[bool] = None,
|
| 159 |
+
return_dict: Optional[bool] = None,
|
| 160 |
+
labels: Optional[torch.Tensor] = None,
|
| 161 |
+
) -> Union[Tuple, SequenceClassifierOutput]:
|
| 162 |
+
r"""
|
| 163 |
+
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
| 164 |
+
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
|
| 165 |
+
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
|
| 166 |
+
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
| 167 |
+
"""
|
| 168 |
+
|
| 169 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 170 |
+
output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states
|
| 171 |
+
|
| 172 |
+
outputs = self.wav2vec2(
|
| 173 |
+
input_values,
|
| 174 |
+
attention_mask=attention_mask,
|
| 175 |
+
output_attentions=output_attentions,
|
| 176 |
+
output_hidden_states=output_hidden_states,
|
| 177 |
+
return_dict=return_dict,
|
| 178 |
+
)
|
| 179 |
+
|
| 180 |
+
if self.config.use_weighted_layer_sum:
|
| 181 |
+
hidden_states = outputs[Wav2Vec2ForSequenceClassificationConfig._HIDDEN_STATES_START_POSITION]
|
| 182 |
+
hidden_states = torch.stack(hidden_states, dim=1)
|
| 183 |
+
norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
|
| 184 |
+
hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
|
| 185 |
+
else:
|
| 186 |
+
hidden_states = outputs[0]
|
| 187 |
+
|
| 188 |
+
_hidden_states = self.projector(hidden_states)
|
| 189 |
+
if attention_mask is None:
|
| 190 |
+
pooled_output = _hidden_states.mean(dim=1)
|
| 191 |
+
else:
|
| 192 |
+
padding_mask = self._get_feature_vector_attention_mask(_hidden_states.shape[1], attention_mask)
|
| 193 |
+
expand_padding_mask = padding_mask.unsqueeze(-1).repeat(1, 1, _hidden_states.shape[2])
|
| 194 |
+
_hidden_states[~expand_padding_mask] = 0.0
|
| 195 |
+
pooled_output = _hidden_states.sum(dim=1) / padding_mask.sum(dim=1).view(-1, 1)
|
| 196 |
+
|
| 197 |
+
logits = self.classifier(pooled_output)
|
| 198 |
+
|
| 199 |
+
loss = None
|
| 200 |
+
if labels is not None:
|
| 201 |
+
shared_head = Wav2Vec2ForSequenceClassificationConfig.gen_shared_head(self, attention_mask)
|
| 202 |
+
criterion = Wav2Vec2ForSequenceClassificationConfig.gen_criterion()
|
| 203 |
+
loss = trp_criterion(self.trp_blocks, shared_head, criterion, lambdas, hidden_states, logits.view(-1, self.config.num_labels), labels.view(-1), loss_normalization) # NOTE: Apply TRP!
|
| 204 |
+
|
| 205 |
+
if not return_dict:
|
| 206 |
+
output = (logits,) + outputs[Wav2Vec2ForSequenceClassificationConfig._HIDDEN_STATES_START_POSITION:]
|
| 207 |
+
return ((loss,) + output) if loss is not None else output
|
| 208 |
+
|
| 209 |
+
return SequenceClassifierOutput(
|
| 210 |
+
loss=loss,
|
| 211 |
+
logits=logits,
|
| 212 |
+
hidden_states=outputs.hidden_states,
|
| 213 |
+
attentions=outputs.attentions,
|
| 214 |
+
)
|
| 215 |
+
return func
|
| 216 |
+
|
| 217 |
+
|
| 218 |
+
# MobileNetV2 for Image Classification
|
| 219 |
+
class MobileNetV2Config(Config):
|
| 220 |
+
@staticmethod
|
| 221 |
+
def gen_criterion(label_smoothing=0.0, top_k=1):
|
| 222 |
+
def func(input, target, mask=None):
|
| 223 |
+
"""
|
| 224 |
+
Args:
|
| 225 |
+
input (Tensor): Input tensor of shape [B, C].
|
| 226 |
+
target (Tensor): Target labels of shape [B] or [B, C].
|
| 227 |
+
|
| 228 |
+
Returns:
|
| 229 |
+
loss (Tensor): Scalar tensor representing the loss.
|
| 230 |
+
mask (Tensor): Boolean mask tensor of shape [B].
|
| 231 |
+
"""
|
| 232 |
+
label = torch.argmax(target, dim=1) if label_smoothing > 0.0 else target
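# Note (added): when label smoothing (and mixup/cutmix) is active, `target` is presumably a
# soft/one-hot [B, C] tensor, so hard labels are recovered with the argmax above; otherwise
# `target` is already a [B] index tensor.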
|
| 233 |
+
|
| 234 |
+
unmasked_loss = F.cross_entropy(input, label, reduction="none", label_smoothing=label_smoothing)
|
| 235 |
+
if mask is None:
|
| 236 |
+
mask = torch.ones_like(unmasked_loss, dtype=torch.float32, device=target.device)
|
| 237 |
+
loss = torch.sum(mask * unmasked_loss) / (torch.sum(mask) + 1e-6)
|
| 238 |
+
|
| 239 |
+
with torch.no_grad():
|
| 240 |
+
topk_values, topk_indices = torch.topk(input, top_k, dim=-1)
|
| 241 |
+
mask = mask * torch.eq(topk_indices, label[:, None]).any(dim=-1).to(input.dtype)
|
| 242 |
+
|
| 243 |
+
return loss, mask
|
| 244 |
+
return func
|
| 245 |
+
|
| 246 |
+
@staticmethod
|
| 247 |
+
def gen_shared_head(self):
|
| 248 |
+
def func(x):
|
| 249 |
+
"""
|
| 250 |
+
Args:
|
| 251 |
+
x (Tensor): Hidden States tensor of shape [B, hidden_units].
|
| 252 |
+
|
| 253 |
+
Returns:
|
| 254 |
+
logits (Tensor): Logits tensor of shape [B, C].
|
| 255 |
+
"""
|
| 256 |
+
logits = self.classifier(x)
|
| 257 |
+
return logits
|
| 258 |
+
return func
|
| 259 |
+
|
| 260 |
+
@staticmethod
|
| 261 |
+
def gen_forward(lambdas, loss_normalization=True, label_smoothing=0.0, top_k=1):
|
| 262 |
+
def func(self, images: Tensor, targets=None):
|
| 263 |
+
x = self.features(images)
|
| 264 |
+
x = nn.functional.adaptive_avg_pool2d(x, (1, 1))
|
| 265 |
+
x = torch.flatten(x, 1)
|
| 266 |
+
logits = self.classifier(x)
|
| 267 |
+
|
| 268 |
+
if self.training:
|
| 269 |
+
torch._assert(targets is not None, "targets should not be none when in training mode")
|
| 270 |
+
shared_head = MobileNetV2Config.gen_shared_head(self)
|
| 271 |
+
criterion = MobileNetV2Config.gen_criterion(label_smoothing, top_k)
|
| 272 |
+
loss = trp_criterion(self.trp_blocks, shared_head, criterion, lambdas, x, logits, targets, loss_normalization)
|
| 273 |
+
return logits, loss
|
| 274 |
+
return logits
|
| 275 |
+
return func
|
| 276 |
+
|
| 277 |
+
|
| 278 |
+
# ResNet for Image Classification
|
| 279 |
+
class ResNetConfig(MobileNetV2Config):
|
| 280 |
+
@staticmethod
|
| 281 |
+
def gen_shared_head(self):
|
| 282 |
+
def func(x):
|
| 283 |
+
"""
|
| 284 |
+
Args:
|
| 285 |
+
x (Tensor): Hidden States tensor of shape [B, hidden_units].
|
| 286 |
+
|
| 287 |
+
Returns:
|
| 288 |
+
logits (Tensor): Logits tensor of shape [B, C].
|
| 289 |
+
"""
|
| 290 |
+
logits = self.fc(x)
|
| 291 |
+
return logits
|
| 292 |
+
return func
|
| 293 |
+
|
| 294 |
+
@staticmethod
|
| 295 |
+
def gen_forward(lambdas, loss_normalization=True, label_smoothing=0.0, top_k=1):
|
| 296 |
+
def func(self, images: Tensor, targets=None):
|
| 297 |
+
x = self.conv1(images)
|
| 298 |
+
x = self.bn1(x)
|
| 299 |
+
x = self.relu(x)
|
| 300 |
+
x = self.maxpool(x)
|
| 301 |
+
|
| 302 |
+
x = self.layer1(x)
|
| 303 |
+
x = self.layer2(x)
|
| 304 |
+
x = self.layer3(x)
|
| 305 |
+
x = self.layer4(x)
|
| 306 |
+
|
| 307 |
+
x = self.avgpool(x)
|
| 308 |
+
x = torch.flatten(x, 1)
|
| 309 |
+
logits = self.fc(x)
|
| 310 |
+
|
| 311 |
+
if self.training:
|
| 312 |
+
torch._assert(targets is not None, "targets should not be none when in training mode")
|
| 313 |
+
shared_head = ResNetConfig.gen_shared_head(self)
|
| 314 |
+
criterion = ResNetConfig.gen_criterion(label_smoothing, top_k)
|
| 315 |
+
loss = trp_criterion(self.trp_blocks, shared_head, criterion, lambdas, x, logits, targets, loss_normalization)
|
| 316 |
+
return logits, loss
|
| 317 |
+
return logits
|
| 318 |
+
return func
|
| 319 |
+
|
| 320 |
+
|
| 321 |
+
# EfficientNet for Image Classification
|
| 322 |
+
class EfficientNetConfig(MobileNetV2Config):
|
| 323 |
+
@staticmethod
|
| 324 |
+
def gen_shared_head(self):
|
| 325 |
+
def func(x):
|
| 326 |
+
"""
|
| 327 |
+
Args:
|
| 328 |
+
x (Tensor): Hidden States tensor of shape [B, hidden_units].
|
| 329 |
+
|
| 330 |
+
Returns:
|
| 331 |
+
logits (Tensor): Logits tensor of shape [B, C].
|
| 332 |
+
"""
|
| 333 |
+
logits = self.classifier(x)
|
| 334 |
+
return logits
|
| 335 |
+
return func
|
| 336 |
+
|
| 337 |
+
@staticmethod
|
| 338 |
+
def gen_forward(lambdas, loss_normalization=True, label_smoothing=0.0, top_k=1):
|
| 339 |
+
def func(self, images: Tensor, targets=None):
|
| 340 |
+
x = self.features(images)
|
| 341 |
+
x = self.avgpool(x)
|
| 342 |
+
x = torch.flatten(x, 1)
|
| 343 |
+
logits = self.classifier(x)
|
| 344 |
+
|
| 345 |
+
if self.training:
|
| 346 |
+
torch._assert(targets is not None, "targets should not be none when in training mode")
|
| 347 |
+
shared_head = EfficientNetConfig.gen_shared_head(self)
|
| 348 |
+
criterion = EfficientNetConfig.gen_criterion(label_smoothing, top_k)
|
| 349 |
+
loss = trp_criterion(self.trp_blocks, shared_head, criterion, lambdas, x, logits, targets, loss_normalization)
|
| 350 |
+
return logits, loss
|
| 351 |
+
return logits
|
| 352 |
+
return func
|
| 353 |
+
|
| 354 |
+
|
| 355 |
+
# ViT for Image Classification
|
| 356 |
+
class VisionTransformerConfig(MobileNetV2Config):
|
| 357 |
+
@staticmethod
|
| 358 |
+
def gen_shared_head(self):
|
| 359 |
+
def func(x):
|
| 360 |
+
"""
|
| 361 |
+
Args:
|
| 362 |
+
x (Tensor): Hidden States tensor of shape [B, hidden_units].
|
| 363 |
+
|
| 364 |
+
Returns:
|
| 365 |
+
logits (Tensor): Logits tensor of shape [B, C].
|
| 366 |
+
"""
|
| 367 |
+
logits = self.heads(x)
|
| 368 |
+
return logits
|
| 369 |
+
return func
|
| 370 |
+
|
| 371 |
+
@staticmethod
|
| 372 |
+
def gen_forward(lambdas, loss_normalization=True, label_smoothing=0.0, top_k=1):
|
| 373 |
+
def func(self, images: Tensor, targets=None):
|
| 374 |
+
x = self._process_input(images)
|
| 375 |
+
n = x.shape[0]
|
| 376 |
+
batch_class_token = self.class_token.expand(n, -1, -1)
|
| 377 |
+
x = torch.cat([batch_class_token, x], dim=1)
|
| 378 |
+
x = self.encoder(x)
|
| 379 |
+
x = x[:, 0]
|
| 380 |
+
|
| 381 |
+
logits = self.heads(x)
|
| 382 |
+
|
| 383 |
+
if self.training:
|
| 384 |
+
torch._assert(targets is not None, "targets should not be none when in training mode")
|
| 385 |
+
shared_head = VisionTransformerConfig.gen_shared_head(self)
|
| 386 |
+
criterion = VisionTransformerConfig.gen_criterion(label_smoothing, top_k)
|
| 387 |
+
loss = trp_criterion(self.trp_blocks, shared_head, criterion, lambdas, x, logits, targets, loss_normalization)
|
| 388 |
+
return logits, loss
|
| 389 |
+
return logits
|
| 390 |
+
return func
|
| 391 |
+
|
| 392 |
+
|
| 393 |
+
# Bert for Question Answering
|
| 394 |
+
class BertForQuestionAnsweringConfig(Config):
|
| 395 |
+
@staticmethod
|
| 396 |
+
def gen_criterion(top_k=1):
|
| 397 |
+
def func(input, target: List[Tensor], mask=None):
|
| 398 |
+
"""
|
| 399 |
+
Args:
|
| 400 |
+
input (Tensor): Input tensor of shape [B, C, 2].
|
| 401 |
+
target (List[Tensor]):
|
| 402 |
+
Start Positions of shape [B].
|
| 403 |
+
End Positions of shape [B].
|
| 404 |
+
|
| 405 |
+
Returns:
|
| 406 |
+
loss (Tensor): Scalar tensor representing the loss.
|
| 407 |
+
mask (Tensor): Boolean mask tensor of shape [B].
|
| 408 |
+
"""
|
| 409 |
+
start_positions, end_positions = target
|
| 410 |
+
|
| 411 |
+
if mask is None:
|
| 412 |
+
mask = torch.ones_like(start_positions, dtype=torch.float32, device=start_positions.device)
|
| 413 |
+
|
| 414 |
+
start_logits, end_logits = input.split(1, dim=-1)
|
| 415 |
+
start_logits = start_logits.squeeze(-1).contiguous()
|
| 416 |
+
end_logits = end_logits.squeeze(-1).contiguous()
|
| 417 |
+
|
| 418 |
+
# If we are on multi-GPU, splitting may add an extra dimension
|
| 419 |
+
if len(start_positions.size()) > 1:
|
| 420 |
+
start_positions = start_positions.squeeze(-1)
|
| 421 |
+
if len(end_positions.size()) > 1:
|
| 422 |
+
end_positions = end_positions.squeeze(-1)
|
| 423 |
+
# sometimes the start/end positions are outside our model inputs; we ignore these terms
|
| 424 |
+
ignored_index = start_logits.size(1)
|
| 425 |
+
start_positions = start_positions.clamp(0, ignored_index)
|
| 426 |
+
end_positions = end_positions.clamp(0, ignored_index)
|
| 427 |
+
|
| 428 |
+
masked_start_losses = F.cross_entropy(start_logits, start_positions, ignore_index=ignored_index, reduction="none")
|
| 429 |
+
start_loss = torch.sum(mask * masked_start_losses) / (torch.sum(mask) + 1e-6)
|
| 430 |
+
masked_end_losses = F.cross_entropy(end_logits, end_positions, ignore_index=ignored_index, reduction="none")
|
| 431 |
+
end_loss = torch.sum(mask * masked_end_losses) / (torch.sum(mask) + 1e-6)
|
| 432 |
+
|
| 433 |
+
with torch.no_grad():
|
| 434 |
+
topk_values, topk_indices = torch.topk(start_logits, top_k, dim=1)
|
| 435 |
+
mask = mask * torch.eq(topk_indices, start_positions[:, None]).any(dim=1).to(start_logits.dtype)
|
| 436 |
+
topk_values, topk_indices = torch.topk(end_logits, top_k, dim=1)
|
| 437 |
+
mask = mask * torch.eq(topk_indices, end_positions[:, None]).any(dim=1).to(end_logits.dtype)
|
| 438 |
+
|
| 439 |
+
return (start_loss + end_loss) / 2, mask
|
| 440 |
+
return func
|
| 441 |
+
|
| 442 |
+
@staticmethod
|
| 443 |
+
def gen_shared_head(self):
|
| 444 |
+
def func(hidden_states):
|
| 445 |
+
"""
|
| 446 |
+
Args:
|
| 447 |
+
hidden_states (Tensor): Hidden States of shape [B, C, hidden_units].
|
| 448 |
+
|
| 449 |
+
Returns:
|
| 450 |
+
logits (Tensor): Logits tensor of shape [B, C, 2].
|
| 451 |
+
"""
|
| 452 |
+
logits = self.qa_outputs(hidden_states)
|
| 453 |
+
return logits
|
| 454 |
+
return func
|
| 455 |
+
|
| 456 |
+
@staticmethod
|
| 457 |
+
def gen_forward(lambdas, loss_normalization=True, top_k=1):
|
| 458 |
+
def func(
|
| 459 |
+
self,
|
| 460 |
+
input_ids: Optional[torch.Tensor] = None,
|
| 461 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 462 |
+
token_type_ids: Optional[torch.Tensor] = None,
|
| 463 |
+
position_ids: Optional[torch.Tensor] = None,
|
| 464 |
+
head_mask: Optional[torch.Tensor] = None,
|
| 465 |
+
inputs_embeds: Optional[torch.Tensor] = None,
|
| 466 |
+
start_positions: Optional[torch.Tensor] = None,
|
| 467 |
+
end_positions: Optional[torch.Tensor] = None,
|
| 468 |
+
output_attentions: Optional[bool] = None,
|
| 469 |
+
output_hidden_states: Optional[bool] = None,
|
| 470 |
+
return_dict: Optional[bool] = None,
|
| 471 |
+
) -> Union[Tuple[torch.Tensor], QuestionAnsweringModelOutput]:
|
| 472 |
+
r"""
|
| 473 |
+
start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
| 474 |
+
Labels for position (index) of the start of the labelled span for computing the token classification loss.
|
| 475 |
+
Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
|
| 476 |
+
are not taken into account for computing the loss.
|
| 477 |
+
end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
| 478 |
+
Labels for position (index) of the end of the labelled span for computing the token classification loss.
|
| 479 |
+
Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
|
| 480 |
+
are not taken into account for computing the loss.
|
| 481 |
+
"""
|
| 482 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 483 |
+
|
| 484 |
+
outputs = self.bert(
|
| 485 |
+
input_ids,
|
| 486 |
+
attention_mask=attention_mask,
|
| 487 |
+
token_type_ids=token_type_ids,
|
| 488 |
+
position_ids=position_ids,
|
| 489 |
+
head_mask=head_mask,
|
| 490 |
+
inputs_embeds=inputs_embeds,
|
| 491 |
+
output_attentions=output_attentions,
|
| 492 |
+
output_hidden_states=output_hidden_states,
|
| 493 |
+
return_dict=return_dict,
|
| 494 |
+
)
|
| 495 |
+
|
| 496 |
+
sequence_output = outputs[0]
|
| 497 |
+
|
| 498 |
+
logits = self.qa_outputs(sequence_output)
|
| 499 |
+
start_logits, end_logits = logits.split(1, dim=-1)
|
| 500 |
+
start_logits = start_logits.squeeze(-1).contiguous()
|
| 501 |
+
end_logits = end_logits.squeeze(-1).contiguous()
|
| 502 |
+
|
| 503 |
+
total_loss = None
|
| 504 |
+
if start_positions is not None and end_positions is not None:
|
| 505 |
+
shared_head = BertForQuestionAnsweringConfig.gen_shared_head(self)
|
| 506 |
+
criterion = BertForQuestionAnsweringConfig.gen_criterion()
|
| 507 |
+
total_loss = trp_criterion(self.trp_blocks, shared_head, criterion, lambdas, sequence_output, logits, [start_positions, end_positions], loss_normalization) # NOTE: Apply TRP!
|
| 508 |
+
|
| 509 |
+
if not return_dict:
|
| 510 |
+
output = (start_logits, end_logits) + outputs[2:]
|
| 511 |
+
return ((total_loss,) + output) if total_loss is not None else output
|
| 512 |
+
|
| 513 |
+
return QuestionAnsweringModelOutput(
|
| 514 |
+
loss=total_loss,
|
| 515 |
+
start_logits=start_logits,
|
| 516 |
+
end_logits=end_logits,
|
| 517 |
+
hidden_states=outputs.hidden_states,
|
| 518 |
+
attentions=outputs.attentions,
|
| 519 |
+
)
|
| 520 |
+
return func
|
| 521 |
+
|
| 522 |
+
|
| 523 |
+
# FCN for Semantic Segmentation
|
| 524 |
+
class FCNConfig(Config):
|
| 525 |
+
@staticmethod
|
| 526 |
+
def gen_criterion(top_k=1):
|
| 527 |
+
def func(input, target, mask=None):
|
| 528 |
+
"""
|
| 529 |
+
Args:
|
| 530 |
+
input (Tensor): Input tensor of shape [B, C, H, W].
|
| 531 |
+
target (Tensor): Target labels of shape [B, H, W].
|
| 532 |
+
|
| 533 |
+
Returns:
|
| 534 |
+
loss (Tensor): Scalar tensor representing the loss.
|
| 535 |
+
mask (Tensor): Boolean mask tensor of shape [B, H, W].
|
| 536 |
+
"""
|
| 537 |
+
if mask is None:
|
| 538 |
+
mask = torch.ones_like(target, dtype=torch.float32, device=target.device)
|
| 539 |
+
|
| 540 |
+
masked_loss = F.cross_entropy(input, target, ignore_index=255, reduction="none")
|
| 541 |
+
loss = torch.sum(mask * masked_loss) / (torch.sum(mask) + 1e-6)
|
| 542 |
+
|
| 543 |
+
with torch.no_grad():
|
| 544 |
+
topk_values, topk_indices = torch.topk(input, top_k, dim=1)
|
| 545 |
+
mask = mask * torch.eq(topk_indices, target[:, None, :, :]).any(dim=1).to(input.dtype)
|
| 546 |
+
# mask = mask * torch.eq(torch.argmax(x, dim=1), target).to(x.dtype)
|
| 547 |
+
|
| 548 |
+
return loss, mask
|
| 549 |
+
return func
|
| 550 |
+
|
| 551 |
+
@staticmethod
|
| 552 |
+
def gen_out_shared_head(self, input_shape):
|
| 553 |
+
def func(features):
|
| 554 |
+
"""
|
| 555 |
+
Args:
|
| 556 |
+
features (Tensor): features tensor of shape [B, hidden_units, H, W].
|
| 557 |
+
|
| 558 |
+
Returns:
|
| 559 |
+
result (Tensor): Result tensor of shape [B, C, H, W].
|
| 560 |
+
"""
|
| 561 |
+
x = self.classifier(features)
|
| 562 |
+
result = F.interpolate(x, size=input_shape, mode="bilinear", align_corners=False)
|
| 563 |
+
return result
|
| 564 |
+
return func
|
| 565 |
+
|
| 566 |
+
@staticmethod
|
| 567 |
+
def gen_aux_shared_head(self, input_shape):
|
| 568 |
+
def func(features):
|
| 569 |
+
"""
|
| 570 |
+
Args:
|
| 571 |
+
features (Tensor): features tensor of shape [B, hidden_units, H, W].
|
| 572 |
+
|
| 573 |
+
Returns:
|
| 574 |
+
result (Tensor): Result tensor of shape [B, C, H, W].
|
| 575 |
+
"""
|
| 576 |
+
x = self.aux_classifier(features)
|
| 577 |
+
result = F.interpolate(x, size=input_shape, mode="bilinear", align_corners=False)
|
| 578 |
+
return result
|
| 579 |
+
return func
|
| 580 |
+
|
| 581 |
+
@staticmethod
|
| 582 |
+
def gen_forward(lambdas, loss_normalization=True, top_k=1):
|
| 583 |
+
def func(self, images: Tensor, targets=None):
|
| 584 |
+
input_shape = images.shape[-2:]
|
| 585 |
+
# contract: features is a dict of tensors
|
| 586 |
+
features = self.backbone(images)
|
| 587 |
+
|
| 588 |
+
result = OrderedDict()
|
| 589 |
+
x = features["out"]
|
| 590 |
+
x = self.classifier(x)
|
| 591 |
+
x = F.interpolate(x, size=input_shape, mode="bilinear", align_corners=False)
|
| 592 |
+
result["out"] = x
|
| 593 |
+
|
| 594 |
+
if self.aux_classifier is not None:
|
| 595 |
+
x = features["aux"]
|
| 596 |
+
x = self.aux_classifier(x)
|
| 597 |
+
x = F.interpolate(x, size=input_shape, mode="bilinear", align_corners=False)
|
| 598 |
+
result["aux"] = x
|
| 599 |
+
|
| 600 |
+
if self.training:
|
| 601 |
+
torch._assert(targets is not None, "targets should not be none when in training mode")
|
| 602 |
+
out_shared_head = FCNConfig.gen_out_shared_head(self, input_shape)
|
| 603 |
+
aux_shared_head = FCNConfig.gen_aux_shared_head(self, input_shape)
|
| 604 |
+
criterion = FCNConfig.gen_criterion(top_k)
|
| 605 |
+
out_loss = trp_criterion(self.out_trp_blocks, out_shared_head, criterion, lambdas, features["out"], result["out"], targets, loss_normalization)
|
| 606 |
+
aux_loss = trp_criterion(self.aux_trp_blocks, aux_shared_head, criterion, lambdas, features["aux"], result["aux"], targets, loss_normalization)
|
| 607 |
+
loss = out_loss + 0.5 * aux_loss
|
| 608 |
+
return result, loss
|
| 609 |
+
return result
|
| 610 |
+
return func
|
| 611 |
+
|
| 612 |
+
|
| 613 |
+
# DeepLabV3 for Semantic Segmentation
|
| 614 |
+
class DeepLabV3Config(FCNConfig):
|
| 615 |
+
pass
|
| 616 |
+
|
| 617 |
+
|
| 618 |
+
# Bert for Text Classification
|
| 619 |
+
class BertForSequenceClassificationConfig(Config):
|
| 620 |
+
@staticmethod
|
| 621 |
+
def gen_criterion():
|
| 622 |
+
def func(input, target, mask=None):
|
| 623 |
+
"""
|
| 624 |
+
Args:
|
| 625 |
+
input (Tensor): Input tensor of shape [B, C].
|
| 626 |
+
target (Tensor): Target labels of shape [B].
|
| 627 |
+
|
| 628 |
+
Returns:
|
| 629 |
+
loss (Tensor): Scalar tensor representing the loss.
|
| 630 |
+
mask (Tensor): Boolean mask tensor of shape [B].
|
| 631 |
+
"""
|
| 632 |
+
if mask is None:
|
| 633 |
+
mask = torch.ones_like(target, dtype=torch.float32, device=target.device)
|
| 634 |
+
|
| 635 |
+
unmasked_loss = F.cross_entropy(input, target, reduction="none")
|
| 636 |
+
loss = torch.sum(mask * unmasked_loss) / (torch.sum(mask) + 1e-6)
|
| 637 |
+
|
| 638 |
+
with torch.no_grad():
|
| 639 |
+
mask = mask * torch.eq(torch.argmax(input, dim=1), target).to(input.dtype)
|
| 640 |
+
|
| 641 |
+
return loss, mask
|
| 642 |
+
return func
|
| 643 |
+
|
| 644 |
+
@staticmethod
|
| 645 |
+
def gen_shared_head(self):
|
| 646 |
+
def func(hidden_states):
|
| 647 |
+
"""
|
| 648 |
+
Args:
|
| 649 |
+
hidden_states (Tensor): Hidden States of shape [B, hidden_units].
|
| 650 |
+
|
| 651 |
+
Returns:
|
| 652 |
+
logits (Tensor): Logits tensor of shape [B, C].
|
| 653 |
+
"""
|
| 654 |
+
logits = self.classifier(hidden_states)
|
| 655 |
+
return logits
|
| 656 |
+
return func
|
| 657 |
+
|
| 658 |
+
@staticmethod
|
| 659 |
+
def gen_forward(lambdas, loss_normalization=False):
|
| 660 |
+
def func(
|
| 661 |
+
self,
|
| 662 |
+
input_ids: Optional[torch.Tensor] = None,
|
| 663 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 664 |
+
token_type_ids: Optional[torch.Tensor] = None,
|
| 665 |
+
position_ids: Optional[torch.Tensor] = None,
|
| 666 |
+
head_mask: Optional[torch.Tensor] = None,
|
| 667 |
+
inputs_embeds: Optional[torch.Tensor] = None,
|
| 668 |
+
labels: Optional[torch.Tensor] = None,
|
| 669 |
+
output_attentions: Optional[bool] = None,
|
| 670 |
+
output_hidden_states: Optional[bool] = None,
|
| 671 |
+
return_dict: Optional[bool] = None,
|
| 672 |
+
) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
|
| 673 |
+
r"""
|
| 674 |
+
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
| 675 |
+
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
|
| 676 |
+
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
|
| 677 |
+
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
| 678 |
+
"""
|
| 679 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 680 |
+
|
| 681 |
+
outputs = self.bert(
|
| 682 |
+
input_ids,
|
| 683 |
+
attention_mask=attention_mask,
|
| 684 |
+
token_type_ids=token_type_ids,
|
| 685 |
+
position_ids=position_ids,
|
| 686 |
+
head_mask=head_mask,
|
| 687 |
+
inputs_embeds=inputs_embeds,
|
| 688 |
+
output_attentions=output_attentions,
|
| 689 |
+
output_hidden_states=output_hidden_states,
|
| 690 |
+
return_dict=return_dict,
|
| 691 |
+
)
|
| 692 |
+
|
| 693 |
+
pooled_output = outputs[1]
|
| 694 |
+
|
| 695 |
+
pooled_output = self.dropout(pooled_output)
|
| 696 |
+
logits = self.classifier(pooled_output)
|
| 697 |
+
|
| 698 |
+
loss = None
|
| 699 |
+
if labels is not None:
|
| 700 |
+
assert self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int) # TODO: remove this
|
| 701 |
+
if self.config.problem_type is None:
|
| 702 |
+
if self.num_labels == 1:
|
| 703 |
+
self.config.problem_type = "regression"
|
| 704 |
+
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
|
| 705 |
+
self.config.problem_type = "single_label_classification"
|
| 706 |
+
else:
|
| 707 |
+
self.config.problem_type = "multi_label_classification"
|
| 708 |
+
|
| 709 |
+
if self.config.problem_type == "regression":
|
| 710 |
+
if self.num_labels == 1:
|
| 711 |
+
loss = F.mse_loss(logits.squeeze(), labels.squeeze())
|
| 712 |
+
else:
|
| 713 |
+
loss = F.mse_loss(logits, labels)
|
| 714 |
+
elif self.config.problem_type == "single_label_classification":
|
| 715 |
+
shared_head = BertForSequenceClassificationConfig.gen_shared_head(self)
|
| 716 |
+
criterion = BertForSequenceClassificationConfig.gen_criterion()
|
| 717 |
+
loss = trp_criterion(self.trp_blocks, shared_head, criterion, lambdas, pooled_output, logits, labels, loss_normalization)
|
| 718 |
+
elif self.config.problem_type == "multi_label_classification":
|
| 719 |
+
loss = F.binary_cross_entropy_with_logits(logits, labels)
|
| 720 |
+
if not return_dict:
|
| 721 |
+
output = (logits,) + outputs[2:]
|
| 722 |
+
return ((loss,) + output) if loss is not None else output
|
| 723 |
+
|
| 724 |
+
return SequenceClassifierOutput(
|
| 725 |
+
loss=loss,
|
| 726 |
+
logits=logits,
|
| 727 |
+
hidden_states=outputs.hidden_states,
|
| 728 |
+
attentions=outputs.attentions,
|
| 729 |
+
)
|
| 730 |
+
return func
|
| 731 |
+
|
| 732 |
+
|
| 733 |
+
# Roberta for Text Classification
|
| 734 |
+
class RobertaForSequenceClassificationConfig(BertForSequenceClassificationConfig):
|
| 735 |
+
@staticmethod
|
| 736 |
+
def gen_shared_head(self):
|
| 737 |
+
def func(hidden_states):
|
| 738 |
+
"""
|
| 739 |
+
Args:
|
| 740 |
+
hidden_states (Tensor): Hidden States of shape [B, hidden_units].
|
| 741 |
+
|
| 742 |
+
Returns:
|
| 743 |
+
logits (Tensor): Logits tensor of shape [B, C].
|
| 744 |
+
"""
|
| 745 |
+
logits = self.classifier(hidden_states)
|
| 746 |
+
return logits
|
| 747 |
+
return func
|
| 748 |
+
|
| 749 |
+
@staticmethod
|
| 750 |
+
def gen_forward(lambdas, loss_normalization=False):
|
| 751 |
+
def func(
|
| 752 |
+
self,
|
| 753 |
+
input_ids: Optional[torch.LongTensor] = None,
|
| 754 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
| 755 |
+
token_type_ids: Optional[torch.LongTensor] = None,
|
| 756 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 757 |
+
head_mask: Optional[torch.FloatTensor] = None,
|
| 758 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 759 |
+
labels: Optional[torch.LongTensor] = None,
|
| 760 |
+
output_attentions: Optional[bool] = None,
|
| 761 |
+
output_hidden_states: Optional[bool] = None,
|
| 762 |
+
return_dict: Optional[bool] = None,
|
| 763 |
+
) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
|
| 764 |
+
r"""
|
| 765 |
+
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
| 766 |
+
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
|
| 767 |
+
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
|
| 768 |
+
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
| 769 |
+
"""
|
| 770 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 771 |
+
|
| 772 |
+
outputs = self.roberta(
|
| 773 |
+
input_ids,
|
| 774 |
+
attention_mask=attention_mask,
|
| 775 |
+
token_type_ids=token_type_ids,
|
| 776 |
+
position_ids=position_ids,
|
| 777 |
+
head_mask=head_mask,
|
| 778 |
+
inputs_embeds=inputs_embeds,
|
| 779 |
+
output_attentions=output_attentions,
|
| 780 |
+
output_hidden_states=output_hidden_states,
|
| 781 |
+
return_dict=return_dict,
|
| 782 |
+
)
|
| 783 |
+
sequence_output = outputs[0]
|
| 784 |
+
logits = self.classifier(sequence_output)
|
| 785 |
+
|
| 786 |
+
loss = None
|
| 787 |
+
if labels is not None:
|
| 788 |
+
assert self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int) # TODO: remove this
|
| 789 |
+
# move labels to correct device to enable model parallelism
|
| 790 |
+
labels = labels.to(logits.device)
|
| 791 |
+
if self.config.problem_type is None:
|
| 792 |
+
if self.num_labels == 1:
|
| 793 |
+
self.config.problem_type = "regression"
|
| 794 |
+
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
|
| 795 |
+
self.config.problem_type = "single_label_classification"
|
| 796 |
+
else:
|
| 797 |
+
self.config.problem_type = "multi_label_classification"
|
| 798 |
+
|
| 799 |
+
if self.config.problem_type == "regression":
|
| 800 |
+
if self.num_labels == 1:
|
| 801 |
+
loss = F.mse_loss(logits.squeeze(), labels.squeeze())
|
| 802 |
+
else:
|
| 803 |
+
loss = F.mse_loss(logits, labels)
|
| 804 |
+
elif self.config.problem_type == "single_label_classification":
|
| 805 |
+
shared_head = BertForSequenceClassificationConfig.gen_shared_head(self)
|
| 806 |
+
criterion = BertForSequenceClassificationConfig.gen_criterion()
|
| 807 |
+
loss = trp_criterion(self.trp_blocks, shared_head, criterion, lambdas, sequence_output, logits, labels, loss_normalization)
|
| 808 |
+
elif self.config.problem_type == "multi_label_classification":
|
| 809 |
+
loss = F.binary_cross_entropy_with_logits(logits, labels)
|
| 810 |
+
|
| 811 |
+
if not return_dict:
|
| 812 |
+
output = (logits,) + outputs[2:]
|
| 813 |
+
return ((loss,) + output) if loss is not None else output
|
| 814 |
+
|
| 815 |
+
return SequenceClassifierOutput(
|
| 816 |
+
loss=loss,
|
| 817 |
+
logits=logits,
|
| 818 |
+
hidden_states=outputs.hidden_states,
|
| 819 |
+
attentions=outputs.attentions,
|
| 820 |
+
)
|
| 821 |
+
return func
|
| 822 |
+
|
| 823 |
+
|
| 824 |
+
# Wav2Vec2 for Speech Recognition
|
| 825 |
+
class Wav2Vec2ForCTCConfig(Config):
|
| 826 |
+
_HIDDEN_STATES_START_POSITION = 2
|
| 827 |
+
|
| 828 |
+
@staticmethod
|
| 829 |
+
def greedy_decode_ctc(
|
| 830 |
+
log_probs: torch.Tensor,
|
| 831 |
+
input_lengths: torch.Tensor,
|
| 832 |
+
blank_token_id: int,
|
| 833 |
+
target_lengths: torch.Tensor
|
| 834 |
+
):
|
| 835 |
+
"""
|
| 836 |
+
Convert logits to flattened predictions that match the shape of flattened_targets.
|
| 837 |
+
|
| 838 |
+
Args:
|
| 839 |
+
log_probs: [B, L, V] - log-softmax output
|
| 840 |
+
input_lengths: [B] - actual length of each input
|
| 841 |
+
blank_token_id: int - index of blank token
|
| 842 |
+
target_lengths: [B] - used to determine how many predictions to keep per sample
|
| 843 |
+
|
| 844 |
+
Returns:
|
| 845 |
+
flattened_predictions: 1D tensor, same total length as sum(target_lengths)
|
| 846 |
+
"""
|
| 847 |
+
batch_size = log_probs.size(0)
|
| 848 |
+
decoded_all = []
|
| 849 |
+
|
| 850 |
+
predicted_ids = log_probs.argmax(dim=-1) # [B, L]
|
| 851 |
+
|
| 852 |
+
for i in range(batch_size):
|
| 853 |
+
pred = predicted_ids[i][:input_lengths[i]] # [Li]
|
| 854 |
+
prev = None
|
| 855 |
+
decoded = []
|
| 856 |
+
for token in pred:
|
| 857 |
+
token = token.item()
|
| 858 |
+
if token != blank_token_id and token != prev:
|
| 859 |
+
decoded.append(token)
|
| 860 |
+
prev = token
|
| 861 |
+
# Trim or pad to match target_lengths[i]
|
| 862 |
+
tgt_len = target_lengths[i].item()
|
| 863 |
+
if len(decoded) >= tgt_len:
|
| 864 |
+
decoded = decoded[:tgt_len]
|
| 865 |
+
else:
|
| 866 |
+
decoded = decoded + [blank_token_id] * (tgt_len - len(decoded)) # pad with blank
|
| 867 |
+
decoded_all.extend(decoded)
|
| 868 |
+
|
| 869 |
+
return torch.tensor(decoded_all, dtype=torch.long, device=log_probs.device) # shape: [sum(target_lengths)]
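# Illustrative example (added; not part of the original file): greedy CTC decoding collapses
# repeats and removes blanks before length-matching. With blank id 0, the frame-level
# prediction [0, 3, 3, 0, 5, 5, 5, 0] decodes to [3, 5], which is then trimmed or
# blank-padded to the corresponding target length.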
|
| 870 |
+
|
| 871 |
+
@staticmethod
|
| 872 |
+
def gen_criterion(input_lengths: Tensor, pad_token_id: int, ctc_zero_infinity: bool):
|
| 873 |
+
def func(logits: Tensor, labels: Tensor, mask=None):
|
| 874 |
+
"""
|
| 875 |
+
Args:
|
| 876 |
+
logits (Tensor): Logits tensor of shape [B, L, V] (log-softmax is applied inside).
|
| 877 |
+
labels (Tensor): Padded target labels of shape [B, L'] (values < 0 mark padding).
|
| 878 |
+
|
| 879 |
+
Returns:
|
| 880 |
+
loss (Tensor): Scalar tensor representing the loss.
|
| 881 |
+
mask (Tensor): Boolean mask tensor of shape [B].
|
| 882 |
+
"""
|
| 883 |
+
if mask is None:
|
| 884 |
+
mask = torch.ones_like(input_lengths, dtype=torch.float32, device=input_lengths.device)
|
| 885 |
+
|
| 886 |
+
log_probs = nn.functional.log_softmax(logits, dim=-1, dtype=torch.float32).transpose(0, 1)
|
| 887 |
+
labels_mask = labels >= 0
|
| 888 |
+
target_lengths = labels_mask.sum(-1)
|
| 889 |
+
flattened_targets = labels.masked_select(labels_mask)
|
| 890 |
+
with torch.backends.cudnn.flags(enabled=False):
|
| 891 |
+
masked_losses = nn.functional.ctc_loss(log_probs, flattened_targets, input_lengths, target_lengths, blank=pad_token_id, reduction="none", zero_infinity=ctc_zero_infinity)
|
| 892 |
+
loss = torch.sum(mask * masked_losses) / (torch.sum(mask) + 1e-6)
|
| 893 |
+
|
| 894 |
+
with torch.no_grad():
|
| 895 |
+
thres = 0.5
|
| 896 |
+
flattened_predictions = Wav2Vec2ForCTCConfig.greedy_decode_ctc(
|
| 897 |
+
log_probs.transpose(0, 1), # [B, T, V]
|
| 898 |
+
input_lengths=input_lengths,
|
| 899 |
+
blank_token_id=pad_token_id,
|
| 900 |
+
target_lengths=target_lengths
|
| 901 |
+
)
|
| 902 |
+
token_wise_mask = torch.eq(flattened_predictions, flattened_targets).to(flattened_targets.dtype)
|
| 903 |
+
segment_ids = torch.arange(len(target_lengths), device=target_lengths.device).repeat_interleave(target_lengths)
|
| 904 |
+
sequence_wise_mask = torch.zeros(len(target_lengths), dtype=target_lengths.dtype, device=token_wise_mask.device).scatter_add(0, segment_ids, token_wise_mask)
|
| 905 |
+
mask = mask * torch.ge(sequence_wise_mask, thres * target_lengths).to(flattened_targets.dtype)
|
| 906 |
+
|
| 907 |
+
return loss, mask
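# Note (added): a sequence keeps contributing to the next replica's loss only if at least
# `thres` (50%) of its target tokens are matched by the greedy decoding above.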
|
| 908 |
+
return func
|
| 909 |
+
|
| 910 |
+
@staticmethod
|
| 911 |
+
def gen_shared_head(self):
|
| 912 |
+
def func(hidden_states):
|
| 913 |
+
"""
|
| 914 |
+
Args:
|
| 915 |
+
hidden_states (Tensor): Hidden States of shape [B, L, hidden_units].
|
| 916 |
+
|
| 917 |
+
Returns:
|
| 918 |
+
logits (Tensor): Logits tensor of shape [B, L, V].
|
| 919 |
+
"""
|
| 920 |
+
logits = self.lm_head(hidden_states)
|
| 921 |
+
# log_probs = nn.functional.log_softmax(logits, dim=-1, dtype=torch.float32).transpose(0, 1)
|
| 922 |
+
return logits
|
| 923 |
+
return func
|
| 924 |
+
|
| 925 |
+
@staticmethod
|
| 926 |
+
def gen_forward(lambdas, loss_normalization=False):
|
| 927 |
+
def func(
|
| 928 |
+
self,
|
| 929 |
+
input_values: Optional[torch.Tensor],
|
| 930 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 931 |
+
output_attentions: Optional[bool] = None,
|
| 932 |
+
output_hidden_states: Optional[bool] = None,
|
| 933 |
+
return_dict: Optional[bool] = None,
|
| 934 |
+
labels: Optional[torch.Tensor] = None,
|
| 935 |
+
) -> Union[Tuple, CausalLMOutput]:
|
| 936 |
+
r"""
|
| 937 |
+
labels (`torch.LongTensor` of shape `(batch_size, target_length)`, *optional*):
|
| 938 |
+
Labels for connectionist temporal classification. Note that `target_length` has to be smaller or equal to
|
| 939 |
+
the sequence length of the output logits. Indices are selected in `[-100, 0, ..., config.vocab_size - 1]`.
|
| 940 |
+
All labels set to `-100` are ignored (masked), the loss is only computed for labels in `[0, ...,
|
| 941 |
+
config.vocab_size - 1]`.
|
| 942 |
+
"""
|
| 943 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 944 |
+
|
| 945 |
+
if labels is not None and labels.max() >= self.config.vocab_size:
|
| 946 |
+
raise ValueError(f"Label values must be <= vocab_size: {self.config.vocab_size}")
|
| 947 |
+
|
| 948 |
+
outputs = self.wav2vec2(
|
| 949 |
+
input_values,
|
| 950 |
+
attention_mask=attention_mask,
|
| 951 |
+
output_attentions=output_attentions,
|
| 952 |
+
output_hidden_states=output_hidden_states,
|
| 953 |
+
return_dict=return_dict,
|
| 954 |
+
)
|
| 955 |
+
|
| 956 |
+
hidden_states = outputs[0]
|
| 957 |
+
hidden_states = self.dropout(hidden_states)
|
| 958 |
+
|
| 959 |
+
logits = self.lm_head(hidden_states)
|
| 960 |
+
|
| 961 |
+
loss = None
|
| 962 |
+
if labels is not None:
|
| 963 |
+
# retrieve loss input_lengths from attention_mask
|
| 964 |
+
attention_mask = (
|
| 965 |
+
attention_mask if attention_mask is not None else torch.ones_like(input_values, dtype=torch.long)
|
| 966 |
+
)
|
| 967 |
+
input_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long)
|
| 968 |
+
shared_head = Wav2Vec2ForCTCConfig.gen_shared_head(self)
|
| 969 |
+
criterion = Wav2Vec2ForCTCConfig.gen_criterion(input_lengths, self.config.pad_token_id, self.config.ctc_zero_infinity)
|
| 970 |
+
loss = trp_criterion(self.trp_blocks, shared_head, criterion, lambdas, hidden_states, logits, labels, loss_normalization) # NOTE: Apply TRP!
|
| 971 |
+
|
| 972 |
+
if not return_dict:
|
| 973 |
+
output = (logits,) + outputs[Wav2Vec2ForCTCConfig._HIDDEN_STATES_START_POSITION:]
|
| 974 |
+
return ((loss,) + output) if loss is not None else output
|
| 975 |
+
|
| 976 |
+
return CausalLMOutput(
|
| 977 |
+
loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions
|
| 978 |
+
)
|
| 979 |
+
return func
|
| 980 |
+
|
| 981 |
+
|
| 982 |
+
# MBart for Translation
|
| 983 |
+
class MBartForConditionalGenerationConfig(Config):
|
| 984 |
+
@staticmethod
|
| 985 |
+
def gen_criterion(vocab_size: int, top_k=1):
|
| 986 |
+
def func(logits, labels, mask=None):
|
| 987 |
+
"""
|
| 988 |
+
Args:
|
| 989 |
+
logits (Tensor): Logits tensor of shape [B, L, V].
|
| 990 |
+
labels (Tensor): Target labels of shape [B, L].
|
| 991 |
+
|
| 992 |
+
Returns:
|
| 993 |
+
loss (Tensor): Scalar tensor representing the loss.
|
| 994 |
+
mask (Tensor): Boolean mask tensor of shape [B].
|
| 995 |
+
"""
|
| 996 |
+
if mask is None:
|
| 997 |
+
mask = torch.ones_like(labels.view(-1), dtype=torch.float32, device=labels.device)
|
| 998 |
+
|
| 999 |
+
masked_losses = F.cross_entropy(logits.view(-1, vocab_size), labels.view(-1), reduction="none")
|
| 1000 |
+
loss = torch.sum(mask * masked_losses) / (torch.sum(mask) + 1e-6)
|
| 1001 |
+
|
| 1002 |
+
with torch.no_grad():
|
| 1003 |
+
topk_values, topk_indices = torch.topk(logits.view(-1, vocab_size), top_k, dim=1)
|
| 1004 |
+
mask = mask * torch.eq(topk_indices, labels.view(-1, 1)).any(dim=1).to(logits.dtype)
|
| 1005 |
+
|
| 1006 |
+
return loss, mask
|
| 1007 |
+
return func
|
| 1008 |
+
|
| 1009 |
+
@staticmethod
|
| 1010 |
+
def gen_shared_head(self):
|
| 1011 |
+
def func(hidden_states):
|
| 1012 |
+
"""
|
| 1013 |
+
Args:
|
| 1014 |
+
hidden_states (Tensor): Hidden States of shape [B, L, hidden_units].
|
| 1015 |
+
|
| 1016 |
+
Returns:
|
| 1017 |
+
logits (Tensor): Logits tensor of shape [B, L, V].
|
| 1018 |
+
"""
|
| 1019 |
+
logits = self.lm_head(hidden_states) + self.final_logits_bias
|
| 1020 |
+
return logits
|
| 1021 |
+
return func
|
| 1022 |
+
|
| 1023 |
+
@staticmethod
|
| 1024 |
+
def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int):
|
| 1025 |
+
"""
|
| 1026 |
+
Shift input ids one token to the right, and wrap the last non-pad token (the <LID> token). Note that MBart does not
|
| 1027 |
+
have a single `decoder_start_token_id` in contrast to other Bart-like models.
|
| 1028 |
+
"""
|
| 1029 |
+
prev_output_tokens = input_ids.clone()
|
| 1030 |
+
|
| 1031 |
+
if pad_token_id is None:
|
| 1032 |
+
raise ValueError("self.model.config.pad_token_id has to be defined.")
|
| 1033 |
+
# replace possible -100 values in labels by `pad_token_id`
|
| 1034 |
+
prev_output_tokens.masked_fill_(prev_output_tokens == -100, pad_token_id)
|
| 1035 |
+
|
| 1036 |
+
index_of_eos = (prev_output_tokens.ne(pad_token_id).sum(dim=1) - 1).unsqueeze(-1)
|
| 1037 |
+
decoder_start_tokens = prev_output_tokens.gather(1, index_of_eos).squeeze()
|
| 1038 |
+
prev_output_tokens[:, 1:] = prev_output_tokens[:, :-1].clone()
|
| 1039 |
+
prev_output_tokens[:, 0] = decoder_start_tokens
|
| 1040 |
+
|
| 1041 |
+
return prev_output_tokens
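# Illustrative example (added; not part of the original file): any -100 values are first
# replaced by the pad id, then the row is rotated so the trailing language-id token becomes
# the decoder start token, e.g. [w1, w2, </s>, LID] -> [LID, w1, w2, </s>].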
|
| 1042 |
+
|
| 1043 |
+
@staticmethod
|
| 1044 |
+
def gen_forward(lambdas, loss_normalization=False):
|
| 1045 |
+
def func(
|
| 1046 |
+
self,
|
| 1047 |
+
input_ids: torch.LongTensor = None,
|
| 1048 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 1049 |
+
decoder_input_ids: Optional[torch.LongTensor] = None,
|
| 1050 |
+
decoder_attention_mask: Optional[torch.LongTensor] = None,
|
| 1051 |
+
head_mask: Optional[torch.Tensor] = None,
|
| 1052 |
+
decoder_head_mask: Optional[torch.Tensor] = None,
|
| 1053 |
+
cross_attn_head_mask: Optional[torch.Tensor] = None,
|
| 1054 |
+
encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
| 1055 |
+
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
| 1056 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 1057 |
+
decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 1058 |
+
labels: Optional[torch.LongTensor] = None,
|
| 1059 |
+
use_cache: Optional[bool] = None,
|
| 1060 |
+
output_attentions: Optional[bool] = None,
|
| 1061 |
+
output_hidden_states: Optional[bool] = None,
|
| 1062 |
+
return_dict: Optional[bool] = None,
|
| 1063 |
+
) -> Union[Seq2SeqLMOutput, Tuple[torch.FloatTensor]]:
|
| 1064 |
+
r"""
|
| 1065 |
+
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 1066 |
+
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
| 1067 |
+
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
| 1068 |
+
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
| 1069 |
+
|
| 1070 |
+
Returns:
|
| 1071 |
+
|
| 1072 |
+
"""
|
| 1073 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 1074 |
+
|
| 1075 |
+
if labels is not None:
|
| 1076 |
+
# if use_cache:
|
| 1077 |
+
# logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.")
|
| 1078 |
+
use_cache = False
|
| 1079 |
+
if decoder_input_ids is None and decoder_inputs_embeds is None:
|
| 1080 |
+
decoder_input_ids = MBartForConditionalGenerationConfig.shift_tokens_right(labels, self.config.pad_token_id)
|
| 1081 |
+
|
| 1082 |
+
outputs = self.model(
|
| 1083 |
+
input_ids,
|
| 1084 |
+
attention_mask=attention_mask,
|
| 1085 |
+
decoder_input_ids=decoder_input_ids,
|
| 1086 |
+
encoder_outputs=encoder_outputs,
|
| 1087 |
+
decoder_attention_mask=decoder_attention_mask,
|
| 1088 |
+
head_mask=head_mask,
|
| 1089 |
+
decoder_head_mask=decoder_head_mask,
|
| 1090 |
+
cross_attn_head_mask=cross_attn_head_mask,
|
| 1091 |
+
past_key_values=past_key_values,
|
| 1092 |
+
inputs_embeds=inputs_embeds,
|
| 1093 |
+
decoder_inputs_embeds=decoder_inputs_embeds,
|
| 1094 |
+
use_cache=use_cache,
|
| 1095 |
+
output_attentions=output_attentions,
|
| 1096 |
+
output_hidden_states=output_hidden_states,
|
| 1097 |
+
return_dict=return_dict,
|
| 1098 |
+
)
|
| 1099 |
+
lm_logits = self.lm_head(outputs[0]) + self.final_logits_bias
|
| 1100 |
+
|
| 1101 |
+
masked_lm_loss = None
|
| 1102 |
+
if labels is not None:
|
| 1103 |
+
shared_head = MBartForConditionalGenerationConfig.gen_shared_head(self)
|
| 1104 |
+
criterion = MBartForConditionalGenerationConfig.gen_criterion(self.config.vocab_size)
|
| 1105 |
+
masked_lm_loss = trp_criterion(self.trp_blocks, shared_head, criterion, lambdas, outputs[0], lm_logits, labels, loss_normalization)
|
| 1106 |
+
|
| 1107 |
+
if not return_dict:
|
| 1108 |
+
output = (lm_logits,) + outputs[1:]
|
| 1109 |
+
return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
|
| 1110 |
+
|
| 1111 |
+
return Seq2SeqLMOutput(
|
| 1112 |
+
loss=masked_lm_loss,
|
| 1113 |
+
logits=lm_logits,
|
| 1114 |
+
past_key_values=outputs.past_key_values,
|
| 1115 |
+
decoder_hidden_states=outputs.decoder_hidden_states,
|
| 1116 |
+
decoder_attentions=outputs.decoder_attentions,
|
| 1117 |
+
cross_attentions=outputs.cross_attentions,
|
| 1118 |
+
encoder_last_hidden_state=outputs.encoder_last_hidden_state,
|
| 1119 |
+
encoder_hidden_states=outputs.encoder_hidden_states,
|
| 1120 |
+
encoder_attentions=outputs.encoder_attentions,
|
| 1121 |
+
)
|
| 1122 |
+
return func
|
| 1123 |
+
|
| 1124 |
+
|
| 1125 |
+
def apply_trp(model, depths: int, p: float, lambdas: List[float], **kwargs):
|
| 1126 |
+
if isinstance(model, transformers.Wav2Vec2ForSequenceClassification):
|
| 1127 |
+
print("✅ Applying TRP to Wav2Vec2 for Audio Classification...")
|
| 1128 |
+
model.trp_blocks = torch.nn.ModuleList([TPBlock(depths, 768, p) for _ in lambdas])
|
| 1129 |
+
model.forward = types.MethodType(Wav2Vec2ForSequenceClassificationConfig.gen_forward(lambdas, False), model)
|
| 1130 |
+
elif isinstance(model, MobileNetV2):
|
| 1131 |
+
print("✅ Applying TRP to MobileNetV2 for Image Classification...")
|
| 1132 |
+
model.trp_blocks = torch.nn.ModuleList([TPBlock(depths, 1280, p) for _ in lambdas])
|
| 1133 |
+
model.forward = types.MethodType(MobileNetV2Config.gen_forward(lambdas, True, label_smoothing=0.0, top_k=1), model)
|
| 1134 |
+
elif isinstance(model, ResNet):
|
| 1135 |
+
print("✅ Applying TRP to ResNet for Image Classification...")
|
| 1136 |
+
model.trp_blocks = torch.nn.ModuleList([TPBlock(depths, 2048, p) for _ in lambdas])
|
| 1137 |
+
model.forward = types.MethodType(ResNetConfig.gen_forward(lambdas, True, label_smoothing=0.0, top_k=1), model)
|
| 1138 |
+
elif isinstance(model, EfficientNet):
|
| 1139 |
+
print("✅ Applying TRP to EfficientNet for Image Classification...")
|
| 1140 |
+
model.trp_blocks = torch.nn.ModuleList([TPBlock(depths, 1280, p) for _ in lambdas])
|
| 1141 |
+
model.forward = types.MethodType(EfficientNetConfig.gen_forward(lambdas, True, label_smoothing=kwargs["label_smoothing"], top_k=1), model)
|
| 1142 |
+
elif isinstance(model, VisionTransformer):
|
| 1143 |
+
print("✅ Applying TRP to VisionTransformer for Image Classification...")
|
| 1144 |
+
model.trp_blocks = torch.nn.ModuleList([TPBlock(depths, 768, p) for _ in lambdas])
|
| 1145 |
+
model.forward = types.MethodType(VisionTransformerConfig.gen_forward(lambdas, True, label_smoothing=kwargs["label_smoothing"], top_k=1), model)
|
| 1146 |
+
elif isinstance(model, transformers.BertForQuestionAnswering):
|
| 1147 |
+
print("✅ Applying TRP to Bert for Question Answering...")
|
| 1148 |
+
model.trp_blocks = torch.nn.ModuleList([TPBlock(depths, 768, p) for _ in lambdas])
|
| 1149 |
+
model.forward = types.MethodType(BertForQuestionAnsweringConfig.gen_forward(lambdas, True, 1), model)
|
| 1150 |
+
elif isinstance(model, FCN):
|
| 1151 |
+
print("✅ Applying TRP to FCN for Semantic Segmentation...")
|
| 1152 |
+
model.out_trp_blocks = torch.nn.ModuleList([TPBlock(depths, 2048, p, dim=1) for _ in lambdas])
|
| 1153 |
+
model.aux_trp_blocks = torch.nn.ModuleList([TPBlock(depths, 1024, p, dim=1) for _ in lambdas])
|
| 1154 |
+
model.forward = types.MethodType(FCNConfig.gen_forward(lambdas, True, 1), model)
|
| 1155 |
+
elif isinstance(model, DeepLabV3):
|
| 1156 |
+
print("✅ Applying TRP to DeepLabV3 for Semantic Segmentation...")
|
| 1157 |
+
model.out_trp_blocks = torch.nn.ModuleList([TPBlock(depths, 2048, p, dim=1) for _ in lambdas])
|
| 1158 |
+
model.aux_trp_blocks = torch.nn.ModuleList([TPBlock(depths, 1024, p, dim=1) for _ in lambdas])
|
| 1159 |
+
model.forward = types.MethodType(DeepLabV3Config.gen_forward(lambdas, True, 1), model)
|
| 1160 |
+
elif isinstance(model, transformers.BertForSequenceClassification):
|
| 1161 |
+
print("✅ Applying TRP to Bert for Text Classification...")
|
| 1162 |
+
model.trp_blocks = torch.nn.ModuleList([TPBlock(depths, 768, p) for _ in lambdas])
|
| 1163 |
+
model.forward = types.MethodType(BertForSequenceClassificationConfig.gen_forward(lambdas, False), model)
|
| 1164 |
+
elif isinstance(model, transformers.RobertaForSequenceClassification):
|
| 1165 |
+
print("✅ Applying TRP to Roberta for Text Classification...")
|
| 1166 |
+
model.trp_blocks = torch.nn.ModuleList([TPBlock(depths, 768, p) for _ in lambdas])
|
| 1167 |
+
model.forward = types.MethodType(RobertaForSequenceClassificationConfig.gen_forward(lambdas, False), model)
|
| 1168 |
+
elif isinstance(model, transformers.Wav2Vec2ForCTC):
|
| 1169 |
+
print("✅ Applying TRP to Wav2Vec2 for Speech Recognition...")
|
| 1170 |
+
model.trp_blocks = torch.nn.ModuleList([TPBlock(depths, 1024, p) for _ in lambdas])
|
| 1171 |
+
model.forward = types.MethodType(Wav2Vec2ForCTCConfig.gen_forward(lambdas, False), model)
|
| 1172 |
+
elif isinstance(model, transformers.MBartForConditionalGeneration):
|
| 1173 |
+
print("✅ Applying TRP to MBart for Translation...")
|
| 1174 |
+
model.trp_blocks = torch.nn.ModuleList([TPBlock(depths, 1024, p) for _ in lambdas])
|
| 1175 |
+
model.forward = types.MethodType(MBartForConditionalGenerationConfig.gen_forward(lambdas, False), model)
|
| 1176 |
+
else:
|
| 1177 |
+
torch._assert(
|
| 1178 |
+
isinstance(model, transformers.Wav2Vec2ForSequenceClassification),
|
| 1179 |
+
"The model should be an object of [`Wav2Vec2ForSequenceClassification`].")
|
| 1180 |
+
|
| 1181 |
+
return model
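# Example usage (illustrative sketch, added; not part of the original file). The constants
# follow the resnet50 recipe in run.sh, but any of the supported model types works:
#   import torchvision
#   model = torchvision.models.resnet50(weights="IMAGENET1K_V1")
#   model = apply_trp(model, depths=1, p=0.2, lambdas=[0.4, 0.2, 0.1])
#   model.train()
#   logits, loss = model(images, targets)  # patched forward returns (logits, loss) in training mode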
|
hpo-examples/image-classification/__pycache__/presets.cpython-310.pyc
ADDED
|
Binary file (2.31 kB).
|
|
|
hpo-examples/image-classification/__pycache__/sampler.cpython-310.pyc
ADDED
|
Binary file (2.41 kB).
|
|
|
hpo-examples/image-classification/__pycache__/transforms.cpython-310.pyc
ADDED
|
Binary file (5.29 kB).
|
|
|
hpo-examples/image-classification/__pycache__/trplib.cpython-310.pyc
ADDED
|
Binary file (37.5 kB).
|
|
|
hpo-examples/image-classification/__pycache__/utils.cpython-310.pyc
ADDED
|
Binary file (14.9 kB).
|
|
|
hpo-examples/image-classification/efficientnet_v2_m/model_7.pth
ADDED
|
@@ -0,0 +1,3 @@
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:5f4b45a082517a3e60498e92710f5a97b5869515db52e536d9c92a3b68ae4e8f
|
| 3 |
+
size 454515355
|
hpo-examples/image-classification/mobilenetv2/model_32.pth
ADDED
|
@@ -0,0 +1,3 @@
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:f240e3b954e1b1786878733c3045125461fa67bcb6ce50a200d5ec6e46081bc7
|
| 3 |
+
size 48002008
|
hpo-examples/image-classification/presets.py
ADDED
|
@@ -0,0 +1,71 @@
| 1 |
+
import torch
|
| 2 |
+
from torchvision.transforms import autoaugment, transforms
|
| 3 |
+
from torchvision.transforms.functional import InterpolationMode
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class ClassificationPresetTrain:
|
| 7 |
+
def __init__(
|
| 8 |
+
self,
|
| 9 |
+
*,
|
| 10 |
+
crop_size,
|
| 11 |
+
mean=(0.485, 0.456, 0.406),
|
| 12 |
+
std=(0.229, 0.224, 0.225),
|
| 13 |
+
interpolation=InterpolationMode.BILINEAR,
|
| 14 |
+
hflip_prob=0.5,
|
| 15 |
+
auto_augment_policy=None,
|
| 16 |
+
ra_magnitude=9,
|
| 17 |
+
augmix_severity=3,
|
| 18 |
+
random_erase_prob=0.0,
|
| 19 |
+
):
|
| 20 |
+
trans = [transforms.RandomResizedCrop(crop_size, interpolation=interpolation)]
|
| 21 |
+
if hflip_prob > 0:
|
| 22 |
+
trans.append(transforms.RandomHorizontalFlip(hflip_prob))
|
| 23 |
+
if auto_augment_policy is not None:
|
| 24 |
+
if auto_augment_policy == "ra":
|
| 25 |
+
trans.append(autoaugment.RandAugment(interpolation=interpolation, magnitude=ra_magnitude))
|
| 26 |
+
elif auto_augment_policy == "ta_wide":
|
| 27 |
+
trans.append(autoaugment.TrivialAugmentWide(interpolation=interpolation))
|
| 28 |
+
elif auto_augment_policy == "augmix":
|
| 29 |
+
trans.append(autoaugment.AugMix(interpolation=interpolation, severity=augmix_severity))
|
| 30 |
+
else:
|
| 31 |
+
aa_policy = autoaugment.AutoAugmentPolicy(auto_augment_policy)
|
| 32 |
+
trans.append(autoaugment.AutoAugment(policy=aa_policy, interpolation=interpolation))
|
| 33 |
+
trans.extend(
|
| 34 |
+
[
|
| 35 |
+
transforms.PILToTensor(),
|
| 36 |
+
transforms.ConvertImageDtype(torch.float),
|
| 37 |
+
transforms.Normalize(mean=mean, std=std),
|
| 38 |
+
]
|
| 39 |
+
)
|
| 40 |
+
if random_erase_prob > 0:
|
| 41 |
+
trans.append(transforms.RandomErasing(p=random_erase_prob))
|
| 42 |
+
|
| 43 |
+
self.transforms = transforms.Compose(trans)
|
| 44 |
+
|
| 45 |
+
def __call__(self, img):
|
| 46 |
+
return self.transforms(img)
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
class ClassificationPresetEval:
|
| 50 |
+
def __init__(
|
| 51 |
+
self,
|
| 52 |
+
*,
|
| 53 |
+
crop_size,
|
| 54 |
+
resize_size=256,
|
| 55 |
+
mean=(0.485, 0.456, 0.406),
|
| 56 |
+
std=(0.229, 0.224, 0.225),
|
| 57 |
+
interpolation=InterpolationMode.BILINEAR,
|
| 58 |
+
):
|
| 59 |
+
|
| 60 |
+
self.transforms = transforms.Compose(
|
| 61 |
+
[
|
| 62 |
+
transforms.Resize(resize_size, interpolation=interpolation),
|
| 63 |
+
transforms.CenterCrop(crop_size),
|
| 64 |
+
transforms.PILToTensor(),
|
| 65 |
+
transforms.ConvertImageDtype(torch.float),
|
| 66 |
+
transforms.Normalize(mean=mean, std=std),
|
| 67 |
+
]
|
| 68 |
+
)
|
| 69 |
+
|
| 70 |
+
def __call__(self, img):
|
| 71 |
+
return self.transforms(img)
|
hpo-examples/image-classification/resnet50/model_35.pth
ADDED
|
@@ -0,0 +1,3 @@
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:3b304e1309e3143b9e7109d5b7263369e4e93c949dde9bab44b3d3b193d16361
|
| 3 |
+
size 255177167
|
hpo-examples/image-classification/run.sh
ADDED
|
@@ -0,0 +1,49 @@
|
| 1 |
+
# ✅ --lr 0.00002 Acc@1 71.878 Acc@5 90.286 -> Acc@1 72.104 Acc@5 90.316 (with normalization)
|
| 2 |
+
torchrun --nproc_per_node=4 train.py\
|
| 3 |
+
--data-path /home/cs/Documents/datasets/imagenet\
|
| 4 |
+
--model mobilenet_v2 --output-dir mobilenet_v2 --weights MobileNet_V2_Weights.IMAGENET1K_V1\
|
| 5 |
+
--batch-size 192 --epochs 40 --lr 0.0004 --lr-step-size 10 --lr-gamma 0.5 --wd 0.00004 --apply-trp --trp-depths 1 --trp-p 0.15 --trp-lambdas 0.4 0.2 0.1
|
| 6 |
+
# torchrun --nproc_per_node=4 train.py\
|
| 7 |
+
# --data-path /home/cs/Documents/datasets/imagenet\
|
| 8 |
+
# --model mobilenet_v2 --resume mobilenet_v2/model_32.pth --test-only
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
# ✅ --lr 0.0002 Acc@1 76.130 Acc@5 92.862 -> Acc@1 77.234 Acc@5 93.322 (with normalization)
|
| 12 |
+
torchrun --nproc_per_node=4 train.py\
|
| 13 |
+
--data-path /home/cs/Documents/datasets/imagenet\
|
| 14 |
+
--model resnet50 --output-dir resnet50 --weights ResNet50_Weights.IMAGENET1K_V1\
|
| 15 |
+
--batch-size 64 --epochs 40 --lr 0.0004 --lr-step-size 10 --lr-gamma 0.5 --print-freq 100\
|
| 16 |
+
--apply-trp --trp-depths 1 --trp-p 0.2 --trp-lambdas 0.4 0.2 0.1
|
| 17 |
+
# torchrun --nproc_per_node=4 train.py\
|
| 18 |
+
# --data-path /home/cs/Documents/datasets/imagenet\
|
| 19 |
+
# --model resnet50 --resume resnet50/model_35.pth --test-only
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
# ✅ Test: Acc@1 85.218 Acc@5 97.208
|
| 23 |
+
torchrun --nproc_per_node=4 train.py \
|
| 24 |
+
--data-path /home/cs/Documents/datasets/imagenet\
|
| 25 |
+
--model efficientnet_v2_m --output-dir efficientnet_v2_m --weights EfficientNet_V2_M_Weights.IMAGENET1K_V1\
|
| 26 |
+
--epochs 10 --batch-size 64 --lr 5e-9 --lr-scheduler cosineannealinglr --weight-decay 0.00002 \
|
| 27 |
+
--lr-warmup-method constant --lr-warmup-epochs 8 --lr-warmup-decay 0. \
|
| 28 |
+
--auto-augment ta_wide --random-erase 0.1 --label-smoothing 0.1 --mixup-alpha 0.2 --cutmix-alpha 1.0 --norm-weight-decay 0.0 \
|
| 29 |
+
--train-crop-size 384 --val-crop-size 480 --val-resize-size 480 --ra-sampler --ra-reps 4 --print-freq 100\
|
| 30 |
+
--apply-trp --trp-depths 1 --trp-p 0.2 --trp-lambdas 0.4 0.2 0.1
|
| 31 |
+
# torchrun --nproc_per_node=4 train.py\
|
| 32 |
+
# --data-path /home/cs/Documents/datasets/imagenet\
|
| 33 |
+
# --model efficientnet_v2_m --resume efficientnet_v2_m/model_7.pth --test-only\
|
| 34 |
+
# --val-crop-size 480 --val-resize-size 480
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
# ✅ Test: Acc@1 81.092 Acc@5 95.304
|
| 38 |
+
torchrun --nproc_per_node=4 train.py\
|
| 39 |
+
--data-path /home/cs/Documents/datasets/imagenet\
|
| 40 |
+
--model vit_b_16 --output-dir vit_b_16 --weights ViT_B_16_Weights.IMAGENET1K_V1\
|
| 41 |
+
--epochs 5 --batch-size 196 --opt adamw --lr 5e-9 --lr-scheduler cosineannealinglr --wd 0.3\
|
| 42 |
+
--lr-warmup-method constant --lr-warmup-epochs 3 --lr-warmup-decay 0. \
|
| 43 |
+
--amp --label-smoothing 0.11 --mixup-alpha 0.2 --auto-augment ra --clip-grad-norm 1 --cutmix-alpha 1.0\
|
| 44 |
+
--apply-trp --trp-depths 1 --trp-p 0.1 --trp-lambdas 0.4 0.2 0.1 --print-freq 100
|
| 45 |
+
# torchrun --nproc_per_node=4 train.py\
|
| 46 |
+
# --data-path /home/cs/Documents/datasets/imagenet\
|
| 47 |
+
# --model vit_b_16 --resume vit_b_16/model_4.pth --test-only
|
| 48 |
+
|
| 49 |
+
|
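As a reading aid (not part of the upload): the `--apply-trp`, `--trp-depths`, `--trp-p` and `--trp-lambdas` flags above are forwarded directly to the `apply_trp` call in `train.py` below. A rough sketch of that mapping, using the values from the resnet50 command; the `num_classes=1000` and the label-smoothing default of 0.0 are assumptions about the ImageNet setup, not values printed in the diff.

```python
# Sketch only: how the run.sh flags reach apply_trp (values copied from the resnet50 command above).
import torchvision

from trplib import apply_trp  # hpo-examples/image-classification/trplib.py

model = torchvision.models.get_model(
    "resnet50", weights="ResNet50_Weights.IMAGENET1K_V1", num_classes=1000  # 1000 assumed for ImageNet
)
# --apply-trp --trp-depths 1 --trp-p 0.2 --trp-lambdas 0.4 0.2 0.1 (label smoothing left at its 0.0 default)
model = apply_trp(model, 1, 0.2, [0.4, 0.2, 0.1], label_smoothing=0.0)
```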
hpo-examples/image-classification/sampler.py
ADDED
@@ -0,0 +1,62 @@
import math

import torch
import torch.distributed as dist


class RASampler(torch.utils.data.Sampler):
    """Sampler that restricts data loading to a subset of the dataset for distributed training,
    with repeated augmentation.
    It ensures that each augmented version of a sample will be visible to a
    different process (GPU).
    Heavily based on 'torch.utils.data.DistributedSampler'.

    This is borrowed from the DeiT Repo:
    https://github.com/facebookresearch/deit/blob/main/samplers.py
    """

    def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True, seed=0, repetitions=3):
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available!")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available!")
            rank = dist.get_rank()
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        self.num_samples = int(math.ceil(len(self.dataset) * float(repetitions) / self.num_replicas))
        self.total_size = self.num_samples * self.num_replicas
        self.num_selected_samples = int(math.floor(len(self.dataset) // 256 * 256 / self.num_replicas))
        self.shuffle = shuffle
        self.seed = seed
        self.repetitions = repetitions

    def __iter__(self):
        if self.shuffle:
            # Deterministically shuffle based on epoch
            g = torch.Generator()
            g.manual_seed(self.seed + self.epoch)
            indices = torch.randperm(len(self.dataset), generator=g).tolist()
        else:
            indices = list(range(len(self.dataset)))

        # Add extra samples to make it evenly divisible
        indices = [ele for ele in indices for i in range(self.repetitions)]
        indices += indices[: (self.total_size - len(indices))]
        assert len(indices) == self.total_size

        # Subsample
        indices = indices[self.rank : self.total_size : self.num_replicas]
        assert len(indices) == self.num_samples

        return iter(indices[: self.num_selected_samples])

    def __len__(self):
        return self.num_selected_samples

    def set_epoch(self, epoch):
        self.epoch = epoch
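A minimal sketch of how `RASampler` is wired into a distributed `DataLoader`, mirroring `load_data` in `train.py` (added later in this upload). It assumes the distributed process group, `train_dataset` and `num_epochs` already exist; those names are placeholders, not part of the diff.

```python
# Sketch: repeated-augmentation sampling in a distributed run (assumes torch.distributed is initialised).
import torch

from sampler import RASampler  # hpo-examples/image-classification/sampler.py

train_sampler = RASampler(train_dataset, shuffle=True, repetitions=4)  # matches --ra-reps 4 in run.sh
loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=64, sampler=train_sampler, num_workers=16, pin_memory=True
)
for epoch in range(num_epochs):
    train_sampler.set_epoch(epoch)  # reshuffle deterministically each epoch
    for images, targets in loader:
        ...  # training step
```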
hpo-examples/image-classification/train.py
ADDED
@@ -0,0 +1,524 @@
import datetime
import os
import time
import warnings

import presets
import torch
import torch.utils.data
import torchvision
import transforms
import utils
from sampler import RASampler
from torch import nn
from torch.utils.data.dataloader import default_collate
from torchvision.transforms.functional import InterpolationMode

from trplib import apply_trp


def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args, model_ema=None, scaler=None):
    model.train()
    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value}"))
    metric_logger.add_meter("img/s", utils.SmoothedValue(window_size=10, fmt="{value}"))

    header = f"Epoch: [{epoch}]"
    for i, (image, target) in enumerate(metric_logger.log_every(data_loader, args.print_freq, header)):
        start_time = time.time()
        image, target = image.to(device), target.to(device)
        with torch.amp.autocast("cuda", enabled=scaler is not None):
            # output = model(image)
            # loss = criterion(output, target)
            output, loss = model(image, target)

        optimizer.zero_grad()
        if scaler is not None:
            scaler.scale(loss).backward()
            if args.clip_grad_norm is not None:
                # we should unscale the gradients of optimizer's assigned params when doing gradient clipping
                scaler.unscale_(optimizer)
                nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad_norm)
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            if args.clip_grad_norm is not None:
                nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad_norm)
            optimizer.step()

        if model_ema and i % args.model_ema_steps == 0:
            model_ema.update_parameters(model)
            if epoch < args.lr_warmup_epochs:
                # Reset ema buffer to keep copying weights during warmup period
                model_ema.n_averaged.fill_(0)

        acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
        batch_size = image.shape[0]
        metric_logger.update(loss=loss.item(), lr=optimizer.param_groups[0]["lr"])
        metric_logger.meters["acc1"].update(acc1.item(), n=batch_size)
        metric_logger.meters["acc5"].update(acc5.item(), n=batch_size)
        metric_logger.meters["img/s"].update(batch_size / (time.time() - start_time))


def evaluate(model, criterion, data_loader, device, print_freq=100, log_suffix=""):
    model.eval()
    metric_logger = utils.MetricLogger(delimiter="  ")
    header = f"Test: {log_suffix}"

    num_processed_samples = 0
    with torch.inference_mode():
        for image, target in metric_logger.log_every(data_loader, print_freq, header):
            image = image.to(device, non_blocking=True)
            target = target.to(device, non_blocking=True)
            output = model(image)
            loss = criterion(output, target)

            acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
            # FIXME need to take into account that the datasets
            # could have been padded in distributed setup
            batch_size = image.shape[0]
            metric_logger.update(loss=loss.item())
            metric_logger.meters["acc1"].update(acc1.item(), n=batch_size)
            metric_logger.meters["acc5"].update(acc5.item(), n=batch_size)
            num_processed_samples += batch_size
    # gather the stats from all processes

    num_processed_samples = utils.reduce_across_processes(num_processed_samples)
    if (
        hasattr(data_loader.dataset, "__len__")
        and len(data_loader.dataset) != num_processed_samples
        and torch.distributed.get_rank() == 0
    ):
        # See FIXME above
        warnings.warn(
            f"It looks like the dataset has {len(data_loader.dataset)} samples, but {num_processed_samples} "
            "samples were used for the validation, which might bias the results. "
            "Try adjusting the batch size and / or the world size. "
            "Setting the world size to 1 is always a safe bet."
        )

    metric_logger.synchronize_between_processes()

    print(f"{header} Acc@1 {metric_logger.acc1.global_avg:.3f} Acc@5 {metric_logger.acc5.global_avg:.3f}")
    return metric_logger.acc1.global_avg


def _get_cache_path(filepath):
    import hashlib

    h = hashlib.sha1(filepath.encode()).hexdigest()
    cache_path = os.path.join("~", ".torch", "vision", "datasets", "imagefolder", h[:10] + ".pt")
    cache_path = os.path.expanduser(cache_path)
    return cache_path


def load_data(traindir, valdir, args):
    # Data loading code
    print("Loading data")
    val_resize_size, val_crop_size, train_crop_size = (
        args.val_resize_size,
        args.val_crop_size,
        args.train_crop_size,
    )
    interpolation = InterpolationMode(args.interpolation)

    print("Loading training data")
    st = time.time()
    cache_path = _get_cache_path(traindir)
    if args.cache_dataset and os.path.exists(cache_path):
        # Attention, as the transforms are also cached!
        print(f"Loading dataset_train from {cache_path}")
        dataset, _ = torch.load(cache_path)
    else:
        auto_augment_policy = getattr(args, "auto_augment", None)
        random_erase_prob = getattr(args, "random_erase", 0.0)
        ra_magnitude = args.ra_magnitude
        augmix_severity = args.augmix_severity
        dataset = torchvision.datasets.ImageFolder(
            traindir,
            presets.ClassificationPresetTrain(
                crop_size=train_crop_size,
                interpolation=interpolation,
                auto_augment_policy=auto_augment_policy,
                random_erase_prob=random_erase_prob,
                ra_magnitude=ra_magnitude,
                augmix_severity=augmix_severity,
            ),
        )
        if args.cache_dataset:
            print(f"Saving dataset_train to {cache_path}")
            utils.mkdir(os.path.dirname(cache_path))
            utils.save_on_master((dataset, traindir), cache_path)
    print("Took", time.time() - st)

    print("Loading validation data")
    cache_path = _get_cache_path(valdir)
    if args.cache_dataset and os.path.exists(cache_path):
        # Attention, as the transforms are also cached!
        print(f"Loading dataset_test from {cache_path}")
        dataset_test, _ = torch.load(cache_path)
    else:
        if args.weights and args.test_only:
            weights = torchvision.models.get_weight(args.weights)
            preprocessing = weights.transforms()
        else:
            preprocessing = presets.ClassificationPresetEval(
                crop_size=val_crop_size, resize_size=val_resize_size, interpolation=interpolation
            )

        dataset_test = torchvision.datasets.ImageFolder(
            valdir,
            preprocessing,
        )
        if args.cache_dataset:
            print(f"Saving dataset_test to {cache_path}")
            utils.mkdir(os.path.dirname(cache_path))
            utils.save_on_master((dataset_test, valdir), cache_path)

    print("Creating data loaders")
    if args.distributed:
        if hasattr(args, "ra_sampler") and args.ra_sampler:
            train_sampler = RASampler(dataset, shuffle=True, repetitions=args.ra_reps)
        else:
            train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
        test_sampler = torch.utils.data.distributed.DistributedSampler(dataset_test, shuffle=False)
    else:
        train_sampler = torch.utils.data.RandomSampler(dataset)
        test_sampler = torch.utils.data.SequentialSampler(dataset_test)

    return dataset, dataset_test, train_sampler, test_sampler


def main(args):
    if args.output_dir:
        utils.mkdir(args.output_dir)

    utils.init_distributed_mode(args)
    print(args)

    device = torch.device(args.device)

    if args.use_deterministic_algorithms:
        torch.backends.cudnn.benchmark = False
        torch.use_deterministic_algorithms(True)
    else:
        torch.backends.cudnn.benchmark = True

    train_dir = os.path.join(args.data_path, "train")
    val_dir = os.path.join(args.data_path, "val")
    dataset, dataset_test, train_sampler, test_sampler = load_data(train_dir, val_dir, args)

    collate_fn = None
    num_classes = len(dataset.classes)
    mixup_transforms = []
    if args.mixup_alpha > 0.0:
        mixup_transforms.append(transforms.RandomMixup(num_classes, p=1.0, alpha=args.mixup_alpha))
    if args.cutmix_alpha > 0.0:
        mixup_transforms.append(transforms.RandomCutmix(num_classes, p=1.0, alpha=args.cutmix_alpha))
    if mixup_transforms:
        mixupcutmix = torchvision.transforms.RandomChoice(mixup_transforms)

        def collate_fn(batch):
            return mixupcutmix(*default_collate(batch))

    data_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=args.batch_size,
        sampler=train_sampler,
        num_workers=args.workers,
        pin_memory=True,
        collate_fn=collate_fn,
    )
    data_loader_test = torch.utils.data.DataLoader(
        dataset_test, batch_size=8, sampler=test_sampler, num_workers=args.workers, pin_memory=True
    )

    print("Creating model")
    model = torchvision.models.get_model(args.model, weights=args.weights, num_classes=num_classes)
    if args.apply_trp:
        model = apply_trp(model, args.trp_depths, args.trp_p, args.trp_lambdas, label_smoothing=args.label_smoothing)
    model.to(device)

    if args.distributed and args.sync_bn:
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    criterion = nn.CrossEntropyLoss(label_smoothing=args.label_smoothing)

    custom_keys_weight_decay = []
    if args.bias_weight_decay is not None:
        custom_keys_weight_decay.append(("bias", args.bias_weight_decay))
    if args.transformer_embedding_decay is not None:
        for key in ["class_token", "position_embedding", "relative_position_bias_table"]:
            custom_keys_weight_decay.append((key, args.transformer_embedding_decay))
    parameters = utils.set_weight_decay(
        model,
        args.weight_decay,
        norm_weight_decay=args.norm_weight_decay,
        custom_keys_weight_decay=custom_keys_weight_decay if len(custom_keys_weight_decay) > 0 else None,
    )

    opt_name = args.opt.lower()
    if opt_name.startswith("sgd"):
        optimizer = torch.optim.SGD(
            parameters,
            lr=args.lr,
            momentum=args.momentum,
            weight_decay=args.weight_decay,
            nesterov="nesterov" in opt_name,
        )
    elif opt_name == "rmsprop":
        optimizer = torch.optim.RMSprop(
            parameters, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay, eps=0.0316, alpha=0.9
        )
    elif opt_name == "adamw":
        optimizer = torch.optim.AdamW(parameters, lr=args.lr, weight_decay=args.weight_decay)
    else:
        raise RuntimeError(f"Invalid optimizer {args.opt}. Only SGD, RMSprop and AdamW are supported.")

    scaler = torch.amp.GradScaler("cuda") if args.amp else None

    args.lr_scheduler = args.lr_scheduler.lower()
    if args.lr_scheduler == "steplr":
        main_lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma)
    elif args.lr_scheduler == "cosineannealinglr":
        main_lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=args.epochs - args.lr_warmup_epochs, eta_min=args.lr_min
        )
    elif args.lr_scheduler == "exponentiallr":
        main_lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=args.lr_gamma)
    else:
        raise RuntimeError(
            f"Invalid lr scheduler '{args.lr_scheduler}'. Only StepLR, CosineAnnealingLR and ExponentialLR "
            "are supported."
        )

    if args.lr_warmup_epochs > 0:
        if args.lr_warmup_method == "linear":
            warmup_lr_scheduler = torch.optim.lr_scheduler.LinearLR(
                optimizer, start_factor=args.lr_warmup_decay, total_iters=args.lr_warmup_epochs
            )
        elif args.lr_warmup_method == "constant":
            warmup_lr_scheduler = torch.optim.lr_scheduler.ConstantLR(
                optimizer, factor=args.lr_warmup_decay, total_iters=args.lr_warmup_epochs
            )
        else:
            raise RuntimeError(
                f"Invalid warmup lr method '{args.lr_warmup_method}'. Only linear and constant are supported."
            )
        lr_scheduler = torch.optim.lr_scheduler.SequentialLR(
            optimizer, schedulers=[warmup_lr_scheduler, main_lr_scheduler], milestones=[args.lr_warmup_epochs]
        )
    else:
        lr_scheduler = main_lr_scheduler

    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        model_without_ddp = model.module

    model_ema = None
    if args.model_ema:
        # Decay adjustment that aims to keep the decay independent from other hyper-parameters originally proposed at:
        # https://github.com/facebookresearch/pycls/blob/f8cd9627/pycls/core/net.py#L123
        #
        # total_ema_updates = (Dataset_size / n_GPUs) * epochs / (batch_size_per_gpu * EMA_steps)
        # We consider constant = Dataset_size for a given dataset/setup and omit it. Thus:
        # adjust = 1 / total_ema_updates ~= n_GPUs * batch_size_per_gpu * EMA_steps / epochs
        adjust = args.world_size * args.batch_size * args.model_ema_steps / args.epochs
        alpha = 1.0 - args.model_ema_decay
        alpha = min(1.0, alpha * adjust)
        model_ema = utils.ExponentialMovingAverage(model_without_ddp, device=device, decay=1.0 - alpha)

    if args.resume:
        checkpoint = torch.load(args.resume, map_location="cpu", weights_only=False)
        model_without_ddp.load_state_dict(checkpoint["model"])
        if not args.test_only:
            optimizer.load_state_dict(checkpoint["optimizer"])
            lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
        args.start_epoch = checkpoint["epoch"] + 1
        if model_ema:
            model_ema.load_state_dict(checkpoint["model_ema"])
        if scaler:
            scaler.load_state_dict(checkpoint["scaler"])

    if args.test_only:
        # We disable the cudnn benchmarking because it can noticeably affect the accuracy
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True
        if model_ema:
            evaluate(model_ema, criterion, data_loader_test, device=device, log_suffix="EMA")
        else:
            evaluate(model, criterion, data_loader_test, device=device)
        return

    print("Start training")
    start_time = time.time()
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args, model_ema, scaler)
        lr_scheduler.step()
        evaluate(model, criterion, data_loader_test, device=device)
        if model_ema:
            evaluate(model_ema, criterion, data_loader_test, device=device, log_suffix="EMA")
        if args.output_dir:
            checkpoint = {
                "model": model_without_ddp.state_dict() if not args.apply_trp else {k: v for k, v in model_without_ddp.state_dict().items() if not k.startswith("trp_blocks")},  # NOTE: remove TRP heads
                "optimizer": optimizer.state_dict(),
                "lr_scheduler": lr_scheduler.state_dict(),
                "epoch": epoch,
                "args": args,
            }
            if model_ema:
                checkpoint["model_ema"] = model_ema.state_dict() if not args.apply_trp else {k: v for k, v in model_ema.state_dict().items() if not k.startswith("trp_blocks")}  # NOTE: remove TRP heads
            if scaler:
                checkpoint["scaler"] = scaler.state_dict()
            utils.save_on_master(checkpoint, os.path.join(args.output_dir, f"model_{epoch}.pth"))
            utils.save_on_master(checkpoint, os.path.join(args.output_dir, "checkpoint.pth"))

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print(f"Training time {total_time_str}")


def get_args_parser(add_help=True):
    import argparse

    parser = argparse.ArgumentParser(description="PyTorch Classification Training", add_help=add_help)

    parser.add_argument("--data-path", default="/datasets01/imagenet_full_size/061417/", type=str, help="dataset path")
    parser.add_argument("--model", default="resnet18", type=str, help="model name")
    parser.add_argument("--device", default="cuda", type=str, help="device (Use cuda or cpu Default: cuda)")
    parser.add_argument(
        "-b", "--batch-size", default=32, type=int, help="images per gpu, the total batch size is $NGPU x batch_size"
    )
    parser.add_argument("--epochs", default=90, type=int, metavar="N", help="number of total epochs to run")
    parser.add_argument(
        "-j", "--workers", default=16, type=int, metavar="N", help="number of data loading workers (default: 16)"
    )
    parser.add_argument("--opt", default="sgd", type=str, help="optimizer")
    parser.add_argument("--lr", default=0.1, type=float, help="initial learning rate")
    parser.add_argument("--momentum", default=0.9, type=float, metavar="M", help="momentum")
    parser.add_argument(
        "--wd",
        "--weight-decay",
        default=1e-4,
        type=float,
        metavar="W",
        help="weight decay (default: 1e-4)",
        dest="weight_decay",
    )
    parser.add_argument(
        "--norm-weight-decay",
        default=None,
        type=float,
        help="weight decay for Normalization layers (default: None, same value as --wd)",
    )
    parser.add_argument(
        "--bias-weight-decay",
        default=None,
        type=float,
        help="weight decay for bias parameters of all layers (default: None, same value as --wd)",
    )
    parser.add_argument(
        "--transformer-embedding-decay",
        default=None,
        type=float,
        help="weight decay for embedding parameters for vision transformer models (default: None, same value as --wd)",
    )
    parser.add_argument(
        "--label-smoothing", default=0.0, type=float, help="label smoothing (default: 0.0)", dest="label_smoothing"
    )
    parser.add_argument("--mixup-alpha", default=0.0, type=float, help="mixup alpha (default: 0.0)")
    parser.add_argument("--cutmix-alpha", default=0.0, type=float, help="cutmix alpha (default: 0.0)")
    parser.add_argument("--lr-scheduler", default="steplr", type=str, help="the lr scheduler (default: steplr)")
    parser.add_argument("--lr-warmup-epochs", default=0, type=int, help="the number of epochs to warmup (default: 0)")
    parser.add_argument(
        "--lr-warmup-method", default="constant", type=str, help="the warmup method (default: constant)"
    )
    parser.add_argument("--lr-warmup-decay", default=0.01, type=float, help="the decay for lr")
    parser.add_argument("--lr-step-size", default=30, type=int, help="decrease lr every step-size epochs")
    parser.add_argument("--lr-gamma", default=0.1, type=float, help="decrease lr by a factor of lr-gamma")
    parser.add_argument("--lr-min", default=0.0, type=float, help="minimum lr of lr schedule (default: 0.0)")
    parser.add_argument("--print-freq", default=10, type=int, help="print frequency")
    parser.add_argument("--output-dir", default=".", type=str, help="path to save outputs")
    parser.add_argument("--resume", default="", type=str, help="path of checkpoint")
    parser.add_argument("--start-epoch", default=0, type=int, metavar="N", help="start epoch")
    parser.add_argument(
        "--cache-dataset",
        dest="cache_dataset",
        help="Cache the datasets for quicker initialization. It also serializes the transforms",
        action="store_true",
    )
    parser.add_argument(
        "--sync-bn",
        dest="sync_bn",
        help="Use sync batch norm",
        action="store_true",
    )
    parser.add_argument(
        "--test-only",
        dest="test_only",
        help="Only test the model",
        action="store_true",
    )
    parser.add_argument("--auto-augment", default=None, type=str, help="auto augment policy (default: None)")
    parser.add_argument("--ra-magnitude", default=9, type=int, help="magnitude of auto augment policy")
    parser.add_argument("--augmix-severity", default=3, type=int, help="severity of augmix policy")
    parser.add_argument("--random-erase", default=0.0, type=float, help="random erasing probability (default: 0.0)")

    # Mixed precision training parameters
    parser.add_argument("--amp", action="store_true", help="Use torch.cuda.amp for mixed precision training")

    # distributed training parameters
    parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes")
    parser.add_argument("--dist-url", default="env://", type=str, help="url used to set up distributed training")
    parser.add_argument(
        "--model-ema", action="store_true", help="enable tracking Exponential Moving Average of model parameters"
    )
    parser.add_argument(
        "--model-ema-steps",
        type=int,
        default=32,
        help="the number of iterations that controls how often to update the EMA model (default: 32)",
    )
    parser.add_argument(
        "--model-ema-decay",
        type=float,
        default=0.99998,
        help="decay factor for Exponential Moving Average of model parameters (default: 0.99998)",
    )
    parser.add_argument(
        "--use-deterministic-algorithms", action="store_true", help="Forces the use of deterministic algorithms only."
    )
    parser.add_argument(
        "--interpolation", default="bilinear", type=str, help="the interpolation method (default: bilinear)"
    )
    parser.add_argument(
        "--val-resize-size", default=256, type=int, help="the resize size used for validation (default: 256)"
    )
    parser.add_argument(
        "--val-crop-size", default=224, type=int, help="the central crop size used for validation (default: 224)"
    )
    parser.add_argument(
        "--train-crop-size", default=224, type=int, help="the random crop size used for training (default: 224)"
    )
    parser.add_argument("--clip-grad-norm", default=None, type=float, help="the maximum gradient norm (default None)")
    parser.add_argument("--ra-sampler", action="store_true", help="whether to use Repeated Augmentation in training")
    parser.add_argument(
        "--ra-reps", default=3, type=int, help="number of repetitions for Repeated Augmentation (default: 3)"
    )
    parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")

    parser.add_argument("--apply-trp", action="store_true", help="enable applying trp")
    parser.add_argument("--trp-depths", type=int, help="trp depth")
    parser.add_argument("--trp-p", type=float, help="trp p")
    parser.add_argument("--trp-lambdas", nargs="+", type=float, help="trp lambdas")

    return parser


if __name__ == "__main__":
    args = get_args_parser().parse_args()
    main(args)
hpo-examples/image-classification/train_quantization.py
ADDED
@@ -0,0 +1,265 @@
import copy
import datetime
import os
import time

import torch
import torch.ao.quantization
import torch.utils.data
import torchvision
import utils
from torch import nn
from train import evaluate, load_data, train_one_epoch


def main(args):
    if args.output_dir:
        utils.mkdir(args.output_dir)

    utils.init_distributed_mode(args)
    print(args)

    if args.post_training_quantize and args.distributed:
        raise RuntimeError("Post training quantization example should not be performed on distributed mode")

    # Set backend engine to ensure that quantized model runs on the correct kernels
    if args.backend not in torch.backends.quantized.supported_engines:
        raise RuntimeError("Quantized backend not supported: " + str(args.backend))
    torch.backends.quantized.engine = args.backend

    device = torch.device(args.device)
    torch.backends.cudnn.benchmark = True

    # Data loading code
    print("Loading data")
    train_dir = os.path.join(args.data_path, "train")
    val_dir = os.path.join(args.data_path, "val")

    dataset, dataset_test, train_sampler, test_sampler = load_data(train_dir, val_dir, args)
    data_loader = torch.utils.data.DataLoader(
        dataset, batch_size=args.batch_size, sampler=train_sampler, num_workers=args.workers, pin_memory=True
    )

    data_loader_test = torch.utils.data.DataLoader(
        dataset_test, batch_size=args.eval_batch_size, sampler=test_sampler, num_workers=args.workers, pin_memory=True
    )

    print("Creating model", args.model)
    # when training quantized models, we always start from a pre-trained fp32 reference model
    prefix = "quantized_"
    model_name = args.model
    if not model_name.startswith(prefix):
        model_name = prefix + model_name
    model = torchvision.models.get_model(model_name, weights=args.weights, quantize=args.test_only)
    model.to(device)

    if not (args.test_only or args.post_training_quantize):
        model.fuse_model(is_qat=True)
        model.qconfig = torch.ao.quantization.get_default_qat_qconfig(args.backend)
        torch.ao.quantization.prepare_qat(model, inplace=True)

        if args.distributed and args.sync_bn:
            model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    optimizer = torch.optim.SGD(
        model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay
    )

    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma)

    criterion = nn.CrossEntropyLoss()
    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        model_without_ddp = model.module

    if args.resume:
        checkpoint = torch.load(args.resume, map_location="cpu")
        model_without_ddp.load_state_dict(checkpoint["model"])
        optimizer.load_state_dict(checkpoint["optimizer"])
        lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
        args.start_epoch = checkpoint["epoch"] + 1

    if args.post_training_quantize:
        # perform calibration on a subset of the training dataset
        # for that, create a subset of the training dataset
        ds = torch.utils.data.Subset(dataset, indices=list(range(args.batch_size * args.num_calibration_batches)))
        data_loader_calibration = torch.utils.data.DataLoader(
            ds, batch_size=args.batch_size, shuffle=False, num_workers=args.workers, pin_memory=True
        )
        model.eval()
        model.fuse_model(is_qat=False)
        model.qconfig = torch.ao.quantization.get_default_qconfig(args.backend)
        torch.ao.quantization.prepare(model, inplace=True)
        # Calibrate first
        print("Calibrating")
        evaluate(model, criterion, data_loader_calibration, device=device, print_freq=1)
        torch.ao.quantization.convert(model, inplace=True)
        if args.output_dir:
            print("Saving quantized model")
            if utils.is_main_process():
                torch.save(model.state_dict(), os.path.join(args.output_dir, "quantized_post_train_model.pth"))
        print("Evaluating post-training quantized model")
        evaluate(model, criterion, data_loader_test, device=device)
        return

    if args.test_only:
        evaluate(model, criterion, data_loader_test, device=device)
        return

    model.apply(torch.ao.quantization.enable_observer)
    model.apply(torch.ao.quantization.enable_fake_quant)
    start_time = time.time()
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        print("Starting training for epoch", epoch)
        train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args)
        lr_scheduler.step()
        with torch.inference_mode():
            if epoch >= args.num_observer_update_epochs:
                print("Disabling observer for subseq epochs, epoch = ", epoch)
                model.apply(torch.ao.quantization.disable_observer)
            if epoch >= args.num_batch_norm_update_epochs:
                print("Freezing BN for subseq epochs, epoch = ", epoch)
                model.apply(torch.nn.intrinsic.qat.freeze_bn_stats)
            print("Evaluate QAT model")

            evaluate(model, criterion, data_loader_test, device=device, log_suffix="QAT")
            quantized_eval_model = copy.deepcopy(model_without_ddp)
            quantized_eval_model.eval()
            quantized_eval_model.to(torch.device("cpu"))
            torch.ao.quantization.convert(quantized_eval_model, inplace=True)

            print("Evaluate Quantized model")
            evaluate(quantized_eval_model, criterion, data_loader_test, device=torch.device("cpu"))

        model.train()

        if args.output_dir:
            checkpoint = {
                "model": model_without_ddp.state_dict(),
                "eval_model": quantized_eval_model.state_dict(),
                "optimizer": optimizer.state_dict(),
                "lr_scheduler": lr_scheduler.state_dict(),
                "epoch": epoch,
                "args": args,
            }
            utils.save_on_master(checkpoint, os.path.join(args.output_dir, f"model_{epoch}.pth"))
            utils.save_on_master(checkpoint, os.path.join(args.output_dir, "checkpoint.pth"))
        print("Saving models after epoch ", epoch)

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print(f"Training time {total_time_str}")


def get_args_parser(add_help=True):
    import argparse

    parser = argparse.ArgumentParser(description="PyTorch Quantized Classification Training", add_help=add_help)

    parser.add_argument("--data-path", default="/datasets01/imagenet_full_size/061417/", type=str, help="dataset path")
    parser.add_argument("--model", default="mobilenet_v2", type=str, help="model name")
    parser.add_argument("--backend", default="qnnpack", type=str, help="fbgemm or qnnpack")
    parser.add_argument("--device", default="cuda", type=str, help="device (Use cuda or cpu Default: cuda)")

    parser.add_argument(
        "-b", "--batch-size", default=32, type=int, help="images per gpu, the total batch size is $NGPU x batch_size"
    )
    parser.add_argument("--eval-batch-size", default=128, type=int, help="batch size for evaluation")
    parser.add_argument("--epochs", default=90, type=int, metavar="N", help="number of total epochs to run")
    parser.add_argument(
        "--num-observer-update-epochs",
        default=4,
        type=int,
        metavar="N",
        help="number of total epochs to update observers",
    )
    parser.add_argument(
        "--num-batch-norm-update-epochs",
        default=3,
        type=int,
        metavar="N",
        help="number of total epochs to update batch norm stats",
    )
    parser.add_argument(
        "--num-calibration-batches",
        default=32,
        type=int,
        metavar="N",
        help="number of batches of training set for \
            observer calibration ",
    )

    parser.add_argument(
        "-j", "--workers", default=16, type=int, metavar="N", help="number of data loading workers (default: 16)"
    )
    parser.add_argument("--lr", default=0.0001, type=float, help="initial learning rate")
    parser.add_argument("--momentum", default=0.9, type=float, metavar="M", help="momentum")
    parser.add_argument(
        "--wd",
        "--weight-decay",
        default=1e-4,
        type=float,
        metavar="W",
        help="weight decay (default: 1e-4)",
        dest="weight_decay",
    )
    parser.add_argument("--lr-step-size", default=30, type=int, help="decrease lr every step-size epochs")
    parser.add_argument("--lr-gamma", default=0.1, type=float, help="decrease lr by a factor of lr-gamma")
    parser.add_argument("--print-freq", default=10, type=int, help="print frequency")
    parser.add_argument("--output-dir", default=".", type=str, help="path to save outputs")
    parser.add_argument("--resume", default="", type=str, help="path of checkpoint")
    parser.add_argument("--start-epoch", default=0, type=int, metavar="N", help="start epoch")
    parser.add_argument(
        "--cache-dataset",
        dest="cache_dataset",
        help="Cache the datasets for quicker initialization. \
            It also serializes the transforms",
        action="store_true",
    )
    parser.add_argument(
        "--sync-bn",
        dest="sync_bn",
        help="Use sync batch norm",
        action="store_true",
    )
    parser.add_argument(
        "--test-only",
        dest="test_only",
        help="Only test the model",
        action="store_true",
    )
    parser.add_argument(
        "--post-training-quantize",
        dest="post_training_quantize",
        help="Post training quantize the model",
        action="store_true",
    )

    # distributed training parameters
    parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes")
    parser.add_argument("--dist-url", default="env://", type=str, help="url used to set up distributed training")

    parser.add_argument(
        "--interpolation", default="bilinear", type=str, help="the interpolation method (default: bilinear)"
    )
    parser.add_argument(
        "--val-resize-size", default=256, type=int, help="the resize size used for validation (default: 256)"
    )
    parser.add_argument(
        "--val-crop-size", default=224, type=int, help="the central crop size used for validation (default: 224)"
    )
    parser.add_argument(
        "--train-crop-size", default=224, type=int, help="the random crop size used for training (default: 224)"
    )
    parser.add_argument("--clip-grad-norm", default=None, type=float, help="the maximum gradient norm (default None)")
    parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")

    return parser


if __name__ == "__main__":
    args = get_args_parser().parse_args()
    main(args)
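The `--post-training-quantize` branch above is the standard eager-mode prepare/calibrate/convert flow. A condensed standalone sketch of that flow, under the assumption of an fbgemm backend, a quantizable torchvision model, and placeholder calibration tensors instead of a real calibration loader:

```python
# Sketch of eager-mode post-training quantization (fbgemm backend and random calibration data are assumptions).
import torch
import torch.ao.quantization
import torchvision

model = torchvision.models.quantization.mobilenet_v2(quantize=False)
model.eval()
model.fuse_model(is_qat=False)                 # fuse conv+bn+relu blocks before inserting observers
model.qconfig = torch.ao.quantization.get_default_qconfig("fbgemm")
torch.ao.quantization.prepare(model, inplace=True)

with torch.inference_mode():                   # calibration passes; use representative images in practice
    for _ in range(8):
        model(torch.randn(1, 3, 224, 224))

torch.ao.quantization.convert(model, inplace=True)  # swap observed modules for quantized kernels
```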
hpo-examples/image-classification/transforms.py
ADDED
|
@@ -0,0 +1,183 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
from typing import Tuple
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
from torch import Tensor
|
| 6 |
+
from torchvision.transforms import functional as F
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class RandomMixup(torch.nn.Module):
|
| 10 |
+
"""Randomly apply Mixup to the provided batch and targets.
|
| 11 |
+
The class implements the data augmentations as described in the paper
|
| 12 |
+
`"mixup: Beyond Empirical Risk Minimization" <https://arxiv.org/abs/1710.09412>`_.
|
| 13 |
+
|
| 14 |
+
Args:
|
| 15 |
+
num_classes (int): number of classes used for one-hot encoding.
|
| 16 |
+
p (float): probability of the batch being transformed. Default value is 0.5.
|
| 17 |
+
alpha (float): hyperparameter of the Beta distribution used for mixup.
|
| 18 |
+
Default value is 1.0.
|
| 19 |
+
inplace (bool): boolean to make this transform inplace. Default set to False.
|
| 20 |
+
"""
|
| 21 |
+
|
| 22 |
+
def __init__(self, num_classes: int, p: float = 0.5, alpha: float = 1.0, inplace: bool = False) -> None:
|
| 23 |
+
super().__init__()
|
| 24 |
+
|
| 25 |
+
if num_classes < 1:
|
| 26 |
+
raise ValueError(
|
| 27 |
+
f"Please provide a valid positive value for the num_classes. Got num_classes={num_classes}"
|
| 28 |
+
)
|
| 29 |
+
|
| 30 |
+
if alpha <= 0:
|
| 31 |
+
raise ValueError("Alpha param can't be zero.")
|
| 32 |
+
|
| 33 |
+
self.num_classes = num_classes
|
| 34 |
+
self.p = p
|
| 35 |
+
self.alpha = alpha
|
| 36 |
+
self.inplace = inplace
|
| 37 |
+
|
| 38 |
+
def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
|
| 39 |
+
"""
|
| 40 |
+
Args:
|
| 41 |
+
batch (Tensor): Float tensor of size (B, C, H, W)
|
| 42 |
+
target (Tensor): Integer tensor of size (B, )
|
| 43 |
+
|
| 44 |
+
Returns:
|
| 45 |
+
Tensor: Randomly transformed batch.
|
| 46 |
+
"""
|
| 47 |
+
if batch.ndim != 4:
|
| 48 |
+
raise ValueError(f"Batch ndim should be 4. Got {batch.ndim}")
|
| 49 |
+
if target.ndim != 1:
|
| 50 |
+
raise ValueError(f"Target ndim should be 1. Got {target.ndim}")
|
| 51 |
+
if not batch.is_floating_point():
|
| 52 |
+
raise TypeError(f"Batch dtype should be a float tensor. Got {batch.dtype}.")
|
| 53 |
+
if target.dtype != torch.int64:
|
| 54 |
+
raise TypeError(f"Target dtype should be torch.int64. Got {target.dtype}")
|
| 55 |
+
|
| 56 |
+
if not self.inplace:
|
| 57 |
+
batch = batch.clone()
|
| 58 |
+
target = target.clone()
|
| 59 |
+
|
| 60 |
+
if target.ndim == 1:
|
| 61 |
+
target = torch.nn.functional.one_hot(target, num_classes=self.num_classes).to(dtype=batch.dtype)
|
| 62 |
+
|
| 63 |
+
if torch.rand(1).item() >= self.p:
|
| 64 |
+
return batch, target
|
| 65 |
+
|
| 66 |
+
# It's faster to roll the batch by one instead of shuffling it to create image pairs
|
| 67 |
+
batch_rolled = batch.roll(1, 0)
|
| 68 |
+
target_rolled = target.roll(1, 0)
|
| 69 |
+
|
| 70 |
+
# Implemented as on mixup paper, page 3.
|
| 71 |
+
lambda_param = float(torch._sample_dirichlet(torch.tensor([self.alpha, self.alpha]))[0])
|
| 72 |
+
batch_rolled.mul_(1.0 - lambda_param)
|
| 73 |
+
batch.mul_(lambda_param).add_(batch_rolled)
|
| 74 |
+
|
| 75 |
+
target_rolled.mul_(1.0 - lambda_param)
|
| 76 |
+
target.mul_(lambda_param).add_(target_rolled)
|
| 77 |
+
|
| 78 |
+
return batch, target
|
| 79 |
+
|
| 80 |
+
def __repr__(self) -> str:
|
| 81 |
+
s = (
|
| 82 |
+
f"{self.__class__.__name__}("
|
| 83 |
+
f"num_classes={self.num_classes}"
|
| 84 |
+
f", p={self.p}"
|
| 85 |
+
f", alpha={self.alpha}"
|
| 86 |
+
f", inplace={self.inplace}"
|
| 87 |
+
f")"
|
| 88 |
+
)
|
| 89 |
+
return s
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
class RandomCutmix(torch.nn.Module):
|
| 93 |
+
"""Randomly apply Cutmix to the provided batch and targets.
|
| 94 |
+
The class implements the data augmentations as described in the paper
|
| 95 |
+
`"CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features"
|
| 96 |
+
<https://arxiv.org/abs/1905.04899>`_.
|
| 97 |
+
|
| 98 |
+
Args:
|
| 99 |
+
num_classes (int): number of classes used for one-hot encoding.
|
| 100 |
+
p (float): probability of the batch being transformed. Default value is 0.5.
|
| 101 |
+
alpha (float): hyperparameter of the Beta distribution used for cutmix.
|
| 102 |
+
Default value is 1.0.
|
| 103 |
+
inplace (bool): boolean to make this transform inplace. Default set to False.
|
| 104 |
+
"""
|
| 105 |
+
|
| 106 |
+
def __init__(self, num_classes: int, p: float = 0.5, alpha: float = 1.0, inplace: bool = False) -> None:
|
| 107 |
+
super().__init__()
|
| 108 |
+
if num_classes < 1:
|
| 109 |
+
raise ValueError("Please provide a valid positive value for the num_classes.")
|
| 110 |
+
if alpha <= 0:
|
| 111 |
+
raise ValueError("Alpha param can't be zero.")
|
| 112 |
+
|
| 113 |
+
self.num_classes = num_classes
|
| 114 |
+
self.p = p
|
| 115 |
+
self.alpha = alpha
|
| 116 |
+
self.inplace = inplace
|
| 117 |
+
|
| 118 |
+
def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
|
| 119 |
+
"""
|
| 120 |
+
Args:
|
| 121 |
+
batch (Tensor): Float tensor of size (B, C, H, W)
|
| 122 |
+
target (Tensor): Integer tensor of size (B, )
|
| 123 |
+
|
| 124 |
+
Returns:
|
| 125 |
+
Tensor: Randomly transformed batch.
|
| 126 |
+
"""
|
| 127 |
+
if batch.ndim != 4:
|
| 128 |
+
raise ValueError(f"Batch ndim should be 4. Got {batch.ndim}")
|
| 129 |
+
if target.ndim != 1:
|
| 130 |
+
raise ValueError(f"Target ndim should be 1. Got {target.ndim}")
|
| 131 |
+
if not batch.is_floating_point():
|
| 132 |
+
raise TypeError(f"Batch dtype should be a float tensor. Got {batch.dtype}.")
|
| 133 |
+
if target.dtype != torch.int64:
|
| 134 |
+
raise TypeError(f"Target dtype should be torch.int64. Got {target.dtype}")
|
| 135 |
+
|
| 136 |
+
if not self.inplace:
|
| 137 |
+
batch = batch.clone()
|
| 138 |
+
target = target.clone()
|
| 139 |
+
|
| 140 |
+
if target.ndim == 1:
|
| 141 |
+
target = torch.nn.functional.one_hot(target, num_classes=self.num_classes).to(dtype=batch.dtype)
|
| 142 |
+
|
| 143 |
+
if torch.rand(1).item() >= self.p:
|
| 144 |
+
return batch, target
|
| 145 |
+
|
| 146 |
+
# It's faster to roll the batch by one instead of shuffling it to create image pairs
|
| 147 |
+
batch_rolled = batch.roll(1, 0)
|
| 148 |
+
target_rolled = target.roll(1, 0)
|
| 149 |
+
|
| 150 |
+
# Implemented as on cutmix paper, page 12 (with minor corrections on typos).
|
| 151 |
+
lambda_param = float(torch._sample_dirichlet(torch.tensor([self.alpha, self.alpha]))[0])
|
| 152 |
+
_, H, W = F.get_dimensions(batch)
|
| 153 |
+
|
| 154 |
+
r_x = torch.randint(W, (1,))
|
| 155 |
+
r_y = torch.randint(H, (1,))
|
| 156 |
+
|
| 157 |
+
r = 0.5 * math.sqrt(1.0 - lambda_param)
|
| 158 |
+
r_w_half = int(r * W)
|
| 159 |
+
r_h_half = int(r * H)
|
| 160 |
+
|
| 161 |
+
x1 = int(torch.clamp(r_x - r_w_half, min=0))
|
| 162 |
+
y1 = int(torch.clamp(r_y - r_h_half, min=0))
|
| 163 |
+
x2 = int(torch.clamp(r_x + r_w_half, max=W))
|
| 164 |
+
y2 = int(torch.clamp(r_y + r_h_half, max=H))
|
| 165 |
+
|
| 166 |
+
batch[:, :, y1:y2, x1:x2] = batch_rolled[:, :, y1:y2, x1:x2]
|
| 167 |
+
lambda_param = float(1.0 - (x2 - x1) * (y2 - y1) / (W * H))
|
| 168 |
+
|
| 169 |
+
target_rolled.mul_(1.0 - lambda_param)
|
| 170 |
+
target.mul_(lambda_param).add_(target_rolled)
|
| 171 |
+
|
| 172 |
+
return batch, target
|
| 173 |
+
|
| 174 |
+
def __repr__(self) -> str:
|
| 175 |
+
s = (
|
| 176 |
+
f"{self.__class__.__name__}("
|
| 177 |
+
f"num_classes={self.num_classes}"
|
| 178 |
+
f", p={self.p}"
|
| 179 |
+
f", alpha={self.alpha}"
|
| 180 |
+
f", inplace={self.inplace}"
|
| 181 |
+
f")"
|
| 182 |
+
)
|
| 183 |
+
return s
|
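
Both batch-level transforms here expect an already-collated (batch, target) pair, so they are normally attached to the DataLoader's collate_fn rather than to the per-sample transform pipeline. A minimal sketch of that wiring, assuming the companion mixup class earlier in this file is the usual RandomMixup and using illustrative hyperparameters (class count, alpha, batch size) rather than values from this repo:

import torch
import torchvision
from torch.utils.data import DataLoader, default_collate

from transforms import RandomCutmix, RandomMixup  # this file

NUM_CLASSES = 1000  # illustrative; the training script passes its own value

mixup = RandomMixup(NUM_CLASSES, p=1.0, alpha=0.2)
cutmix = RandomCutmix(NUM_CLASSES, p=1.0, alpha=1.0)
mixup_or_cutmix = torchvision.transforms.RandomChoice([mixup, cutmix])

def collate_fn(batch):
    # default_collate stacks (image, label) pairs into [B, C, H, W] and [B] tensors,
    # which matches the shape checks both transforms perform above.
    return mixup_or_cutmix(*default_collate(batch))

# loader = DataLoader(dataset, batch_size=32, collate_fn=collate_fn)
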
hpo-examples/image-classification/trplib.py
ADDED
|
@@ -0,0 +1,1181 @@
| 1 |
+
import torch
|
| 2 |
+
from torch import nn, Tensor
|
| 3 |
+
from torch.nn import functional as F
|
| 4 |
+
|
| 5 |
+
from torchvision.models.mobilenetv2 import MobileNetV2
|
| 6 |
+
from torchvision.models.resnet import ResNet
|
| 7 |
+
from torchvision.models.efficientnet import EfficientNet
|
| 8 |
+
from torchvision.models.vision_transformer import VisionTransformer
|
| 9 |
+
from torchvision.models.segmentation.fcn import FCN
|
| 10 |
+
from torchvision.models.segmentation.deeplabv3 import DeepLabV3
|
| 11 |
+
|
| 12 |
+
import transformers
|
| 13 |
+
from transformers.modeling_outputs import SequenceClassifierOutput, QuestionAnsweringModelOutput, CausalLMOutput, Seq2SeqLMOutput
|
| 14 |
+
|
| 15 |
+
from typing import Optional, Tuple, List, Union, Callable
|
| 16 |
+
from collections import OrderedDict
|
| 17 |
+
import types
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def trp_criterion(trp_blocks: nn.ModuleList, shared_head: Callable, criterion: Callable, lambdas: List[float], hidden_states: Tensor, logits: Tensor, targets: Tensor, loss_normalization=False):
|
| 21 |
+
loss, mask = criterion(logits, targets)
|
| 22 |
+
if loss_normalization:
|
| 23 |
+
coeff = loss.detach()
|
| 24 |
+
|
| 25 |
+
embeds = [hidden_states]
|
| 26 |
+
predictions = []
|
| 27 |
+
for k, c in enumerate(lambdas):
|
| 28 |
+
embeds.append(trp_blocks[k](embeds[-1]))
|
| 29 |
+
predictions.append(shared_head(embeds[-1]))
|
| 30 |
+
replica_loss, mask = criterion(predictions[-1], targets, mask)
|
| 31 |
+
loss += c * replica_loss
|
| 32 |
+
|
| 33 |
+
if loss_normalization:
|
| 34 |
+
with torch.no_grad():
|
| 35 |
+
coeff = torch.exp(coeff) / torch.exp(loss.detach())
|
| 36 |
+
loss = coeff * loss
|
| 37 |
+
|
| 38 |
+
return loss
|
| 39 |
+
|
| 40 |
+
|
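
trp_criterion chains each TRP block over the previous hidden state, re-scores the shared head, and weights every replica's loss by the running mask, so a replica only contributes for samples the previous stage already got right. A toy invocation under assumed shapes, with a hypothetical linear head and a cross-entropy criterion that follows the (loss, mask) contract used throughout this file:

import torch
from torch import nn
import torch.nn.functional as F

from trplib import TPBlock, trp_criterion  # this file

hidden = torch.randn(8, 32)            # [B, hidden_units], illustrative
targets = torch.randint(0, 5, (8,))    # [B]
head = nn.Linear(32, 5)                # hypothetical shared head

def criterion(inp, tgt, mask=None):
    if mask is None:
        mask = torch.ones_like(tgt, dtype=torch.float32)
    per_sample = F.cross_entropy(inp, tgt, reduction="none")
    loss = torch.sum(mask * per_sample) / (torch.sum(mask) + 1e-6)
    with torch.no_grad():
        mask = mask * torch.eq(inp.argmax(dim=1), tgt).float()
    return loss, mask

trp_blocks = nn.ModuleList([TPBlock(depths=1, in_features=32, p=0.1) for _ in range(2)])
loss = trp_criterion(trp_blocks, head, criterion, [0.5, 0.25],
                     hidden, head(hidden), targets)
loss.backward()
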
| 41 |
+
class TPBlock(nn.Module):
|
| 42 |
+
def __init__(self, depths: int, in_features: int, p: float, dim=-1):
|
| 43 |
+
super(TPBlock, self).__init__()
|
| 44 |
+
|
| 45 |
+
self.dropout = nn.Dropout(p)
|
| 46 |
+
|
| 47 |
+
self.cdim = dim
|
| 48 |
+
|
| 49 |
+
blocks = []
|
| 50 |
+
for _ in range(depths):
|
| 51 |
+
blocks.append(nn.Linear(in_features, in_features))
|
| 52 |
+
nn.init.constant_(blocks[-1].weight, 0.0)
|
| 53 |
+
nn.init.constant_(blocks[-1].bias, 0.0)
|
| 54 |
+
blocks.append(nn.ReLU())
|
| 55 |
+
self.blocks = nn.Sequential(*blocks)
|
| 56 |
+
|
| 57 |
+
def forward(self, x):
|
| 58 |
+
x = self.dropout(x)
|
| 59 |
+
if self.cdim == -1:
|
| 60 |
+
x = x + self.blocks(x)
|
| 61 |
+
else:
|
| 62 |
+
x = x + torch.movedim(self.blocks(torch.movedim(x, self.cdim, -1)), -1, self.cdim)
|
| 63 |
+
return x
|
| 64 |
+
|
| 65 |
+
|
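
Because every Linear inside TPBlock is zero-initialized, the residual branch contributes nothing at the start of training: a freshly built block is the identity in eval mode and plain dropout in train mode. A quick sanity check of that property (shapes are illustrative):

import torch
from trplib import TPBlock  # this file

block = TPBlock(depths=2, in_features=16, p=0.1)
block.eval()                         # disable dropout
x = torch.randn(4, 10, 16)           # [B, L, hidden_units]
# zero-init weights and biases => blocks(x) == ReLU(0) == 0, so output == x
assert torch.allclose(block(x), x)
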
| 66 |
+
class Config:
|
| 67 |
+
@staticmethod
|
| 68 |
+
def gen_criterion(*args, **kwargs):
|
| 69 |
+
def func(input, target, mask=None):
|
| 70 |
+
"""
|
| 71 |
+
Args:
|
| 72 |
+
input (Tensor): Input tensor.
|
| 73 |
+
target (Tensor): Target labels.
|
| 74 |
+
|
| 75 |
+
Returns:
|
| 76 |
+
loss (Tensor): Scalar tensor representing the loss.
|
| 77 |
+
mask (Tensor): Boolean mask tensor with the same shape as target.
|
| 78 |
+
"""
|
| 79 |
+
pass
|
| 80 |
+
return func
|
| 81 |
+
|
| 82 |
+
@staticmethod
|
| 83 |
+
def gen_shared_head(*args, **kwargs):
|
| 84 |
+
def func(hidden_states):
|
| 85 |
+
"""
|
| 86 |
+
Args:
|
| 87 |
+
hidden_states (Tensor): Hidden States tensor.
|
| 88 |
+
|
| 89 |
+
Returns:
|
| 90 |
+
logits (Tensor): Logits tensor.
|
| 91 |
+
"""
|
| 92 |
+
pass
|
| 93 |
+
return func
|
| 94 |
+
|
| 95 |
+
@staticmethod
|
| 96 |
+
def forward(*args, **kwargs):
|
| 97 |
+
pass
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
# Wav2Vec2 for Audio Classification
|
| 101 |
+
class Wav2Vec2ForSequenceClassificationConfig(Config):
|
| 102 |
+
_HIDDEN_STATES_START_POSITION = 2
|
| 103 |
+
|
| 104 |
+
@staticmethod
|
| 105 |
+
def gen_criterion():
|
| 106 |
+
def func(input, target, mask=None):
|
| 107 |
+
"""
|
| 108 |
+
Args:
|
| 109 |
+
input (Tensor): Input tensor of shape [B, C].
|
| 110 |
+
target (Tensor): Target labels of shape [B].
|
| 111 |
+
|
| 112 |
+
Returns:
|
| 113 |
+
loss (Tensor): Scalar tensor representing the loss.
|
| 114 |
+
mask (Tensor): Boolean mask tensor of shape [B].
|
| 115 |
+
"""
|
| 116 |
+
if mask is None:
|
| 117 |
+
mask = torch.ones_like(target, dtype=torch.float32, device=target.device)
|
| 118 |
+
|
| 119 |
+
unmasked_loss = F.cross_entropy(input, target, reduction="none")
|
| 120 |
+
loss = torch.sum(mask * unmasked_loss) / (torch.sum(mask) + 1e-6)
|
| 121 |
+
|
| 122 |
+
with torch.no_grad():
|
| 123 |
+
mask = mask * torch.eq(torch.argmax(input, dim=1), target).to(input.dtype)
|
| 124 |
+
|
| 125 |
+
return loss, mask
|
| 126 |
+
return func
|
| 127 |
+
|
| 128 |
+
@staticmethod
|
| 129 |
+
def gen_shared_head(self, attention_mask):
|
| 130 |
+
def func(hidden_states):
|
| 131 |
+
"""
|
| 132 |
+
Args:
|
| 133 |
+
hidden_states (Tensor): Hidden States of shape [B, L, hidden_units].
|
| 134 |
+
|
| 135 |
+
Returns:
|
| 136 |
+
logits (Tensor): Logits tensor of shape [B, C].
|
| 137 |
+
"""
|
| 138 |
+
_hidden_states = self.projector(hidden_states)
|
| 139 |
+
if attention_mask is None:
|
| 140 |
+
pooled_output = _hidden_states.mean(dim=1)
|
| 141 |
+
else:
|
| 142 |
+
padding_mask = self._get_feature_vector_attention_mask(_hidden_states.shape[1], attention_mask)
|
| 143 |
+
expand_padding_mask = padding_mask.unsqueeze(-1).repeat(1, 1, _hidden_states.shape[2])
|
| 144 |
+
_hidden_states[~expand_padding_mask] = 0.0
|
| 145 |
+
pooled_output = _hidden_states.sum(dim=1) / padding_mask.sum(dim=1).view(-1, 1)
|
| 146 |
+
|
| 147 |
+
logits = self.classifier(pooled_output)
|
| 148 |
+
return logits
|
| 149 |
+
return func
|
| 150 |
+
|
| 151 |
+
@staticmethod
|
| 152 |
+
def gen_forward(lambdas, loss_normalization=False):
|
| 153 |
+
def func(
|
| 154 |
+
self,
|
| 155 |
+
input_values: Optional[torch.Tensor],
|
| 156 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 157 |
+
output_attentions: Optional[bool] = None,
|
| 158 |
+
output_hidden_states: Optional[bool] = None,
|
| 159 |
+
return_dict: Optional[bool] = None,
|
| 160 |
+
labels: Optional[torch.Tensor] = None,
|
| 161 |
+
) -> Union[Tuple, SequenceClassifierOutput]:
|
| 162 |
+
r"""
|
| 163 |
+
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
| 164 |
+
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
|
| 165 |
+
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss). If
|
| 166 |
+
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
| 167 |
+
"""
|
| 168 |
+
|
| 169 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 170 |
+
output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states
|
| 171 |
+
|
| 172 |
+
outputs = self.wav2vec2(
|
| 173 |
+
input_values,
|
| 174 |
+
attention_mask=attention_mask,
|
| 175 |
+
output_attentions=output_attentions,
|
| 176 |
+
output_hidden_states=output_hidden_states,
|
| 177 |
+
return_dict=return_dict,
|
| 178 |
+
)
|
| 179 |
+
|
| 180 |
+
if self.config.use_weighted_layer_sum:
|
| 181 |
+
hidden_states = outputs[Wav2Vec2ForSequenceClassificationConfig._HIDDEN_STATES_START_POSITION]
|
| 182 |
+
hidden_states = torch.stack(hidden_states, dim=1)
|
| 183 |
+
norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
|
| 184 |
+
hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
|
| 185 |
+
else:
|
| 186 |
+
hidden_states = outputs[0]
|
| 187 |
+
|
| 188 |
+
_hidden_states = self.projector(hidden_states)
|
| 189 |
+
if attention_mask is None:
|
| 190 |
+
pooled_output = _hidden_states.mean(dim=1)
|
| 191 |
+
else:
|
| 192 |
+
padding_mask = self._get_feature_vector_attention_mask(_hidden_states.shape[1], attention_mask)
|
| 193 |
+
expand_padding_mask = padding_mask.unsqueeze(-1).repeat(1, 1, _hidden_states.shape[2])
|
| 194 |
+
_hidden_states[~expand_padding_mask] = 0.0
|
| 195 |
+
pooled_output = _hidden_states.sum(dim=1) / padding_mask.sum(dim=1).view(-1, 1)
|
| 196 |
+
|
| 197 |
+
logits = self.classifier(pooled_output)
|
| 198 |
+
|
| 199 |
+
loss = None
|
| 200 |
+
if labels is not None:
|
| 201 |
+
shared_head = Wav2Vec2ForSequenceClassificationConfig.gen_shared_head(self, attention_mask)
|
| 202 |
+
criterion = Wav2Vec2ForSequenceClassificationConfig.gen_criterion()
|
| 203 |
+
loss = trp_criterion(self.trp_blocks, shared_head, criterion, lambdas, hidden_states, logits.view(-1, self.config.num_labels), labels.view(-1), loss_normalization) # NOTE: Apply TRP!
|
| 204 |
+
|
| 205 |
+
if not return_dict:
|
| 206 |
+
output = (logits,) + outputs[Wav2Vec2ForSequenceClassificationConfig._HIDDEN_STATES_START_POSITION:]
|
| 207 |
+
return ((loss,) + output) if loss is not None else output
|
| 208 |
+
|
| 209 |
+
return SequenceClassifierOutput(
|
| 210 |
+
loss=loss,
|
| 211 |
+
logits=logits,
|
| 212 |
+
hidden_states=outputs.hidden_states,
|
| 213 |
+
attentions=outputs.attentions,
|
| 214 |
+
)
|
| 215 |
+
return func
|
| 216 |
+
|
| 217 |
+
|
| 218 |
+
# MobileNetV2 for Image Classification
|
| 219 |
+
class MobileNetV2Config(Config):
|
| 220 |
+
@staticmethod
|
| 221 |
+
def gen_criterion(label_smoothing=0.0, top_k=1):
|
| 222 |
+
def func(input, target, mask=None):
|
| 223 |
+
"""
|
| 224 |
+
Args:
|
| 225 |
+
input (Tensor): Input tensor of shape [B, C].
|
| 226 |
+
target (Tensor): Target labels of shape [B] or [B, C].
|
| 227 |
+
|
| 228 |
+
Returns:
|
| 229 |
+
loss (Tensor): Scalar tensor representing the loss.
|
| 230 |
+
mask (Tensor): Boolean mask tensor of shape [B].
|
| 231 |
+
"""
|
| 232 |
+
label = torch.argmax(target, dim=1) if label_smoothing > 0.0 else target
|
| 233 |
+
|
| 234 |
+
unmasked_loss = F.cross_entropy(input, label, reduction="none", label_smoothing=label_smoothing)
|
| 235 |
+
if mask is None:
|
| 236 |
+
mask = torch.ones_like(unmasked_loss, dtype=torch.float32, device=target.device)
|
| 237 |
+
loss = torch.sum(mask * unmasked_loss) / (torch.sum(mask) + 1e-6)
|
| 238 |
+
|
| 239 |
+
with torch.no_grad():
|
| 240 |
+
topk_values, topk_indices = torch.topk(input, top_k, dim=-1)
|
| 241 |
+
mask = mask * torch.eq(topk_indices, label[:, None]).any(dim=-1).to(input.dtype)
|
| 242 |
+
|
| 243 |
+
return loss, mask
|
| 244 |
+
return func
|
| 245 |
+
|
| 246 |
+
@staticmethod
|
| 247 |
+
def gen_shared_head(self):
|
| 248 |
+
def func(x):
|
| 249 |
+
"""
|
| 250 |
+
Args:
|
| 251 |
+
x (Tensor): Hidden States tensor of shape [B, hidden_units].
|
| 252 |
+
|
| 253 |
+
Returns:
|
| 254 |
+
logits (Tensor): Logits tensor of shape [B, C].
|
| 255 |
+
"""
|
| 256 |
+
logits = self.classifier(x)
|
| 257 |
+
return logits
|
| 258 |
+
return func
|
| 259 |
+
|
| 260 |
+
@staticmethod
|
| 261 |
+
def gen_forward(lambdas, loss_normalization=True, label_smoothing=0.0, top_k=1):
|
| 262 |
+
def func(self, images: Tensor, targets=None):
|
| 263 |
+
x = self.features(images)
|
| 264 |
+
x = nn.functional.adaptive_avg_pool2d(x, (1, 1))
|
| 265 |
+
x = torch.flatten(x, 1)
|
| 266 |
+
logits = self.classifier(x)
|
| 267 |
+
|
| 268 |
+
if self.training:
|
| 269 |
+
torch._assert(targets is not None, "targets should not be none when in training mode")
|
| 270 |
+
shared_head = MobileNetV2Config.gen_shared_head(self)
|
| 271 |
+
criterion = MobileNetV2Config.gen_criterion(label_smoothing, top_k)
|
| 272 |
+
loss = trp_criterion(self.trp_blocks, shared_head, criterion, lambdas, x, logits, targets, loss_normalization)
|
| 273 |
+
return logits, loss
|
| 274 |
+
return logits
|
| 275 |
+
return func
|
| 276 |
+
|
| 277 |
+
|
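
The generated forward is meant to be bound onto a torchvision model instance with types.MethodType, so that the training path returns (logits, loss); the apply_trp helper at the end of this file does exactly this wiring. A condensed sketch under illustrative settings (class count, lambdas, dropout are assumptions, not repo values):

import types
import torch
from torchvision.models import mobilenet_v2

from trplib import TPBlock, MobileNetV2Config  # this file

model = mobilenet_v2(num_classes=100)   # illustrative class count
lambdas = [0.4, 0.2, 0.1]               # illustrative replica weights
# 1280 is MobileNetV2's final feature width, matching apply_trp below.
model.trp_blocks = torch.nn.ModuleList([TPBlock(1, 1280, p=0.1) for _ in lambdas])
model.forward = types.MethodType(
    MobileNetV2Config.gen_forward(lambdas, loss_normalization=True,
                                  label_smoothing=0.0, top_k=1),  # 0.0 keeps integer [B] targets valid
    model,
)

model.train()
images, targets = torch.randn(2, 3, 224, 224), torch.randint(0, 100, (2,))
logits, loss = model(images, targets)
loss.backward()
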
| 278 |
+
# ResNet for Image Classification
|
| 279 |
+
class ResNetConfig(MobileNetV2Config):
|
| 280 |
+
@staticmethod
|
| 281 |
+
def gen_shared_head(self):
|
| 282 |
+
def func(x):
|
| 283 |
+
"""
|
| 284 |
+
Args:
|
| 285 |
+
x (Tensor): Hidden States tensor of shape [B, hidden_units].
|
| 286 |
+
|
| 287 |
+
Returns:
|
| 288 |
+
logits (Tensor): Logits tensor of shape [B, C].
|
| 289 |
+
"""
|
| 290 |
+
logits = self.fc(x)
|
| 291 |
+
return logits
|
| 292 |
+
return func
|
| 293 |
+
|
| 294 |
+
@staticmethod
|
| 295 |
+
def gen_forward(lambdas, loss_normalization=True, label_smoothing=0.0, top_k=1):
|
| 296 |
+
def func(self, images: Tensor, targets=None):
|
| 297 |
+
x = self.conv1(images)
|
| 298 |
+
x = self.bn1(x)
|
| 299 |
+
x = self.relu(x)
|
| 300 |
+
x = self.maxpool(x)
|
| 301 |
+
|
| 302 |
+
x = self.layer1(x)
|
| 303 |
+
x = self.layer2(x)
|
| 304 |
+
x = self.layer3(x)
|
| 305 |
+
x = self.layer4(x)
|
| 306 |
+
|
| 307 |
+
x = self.avgpool(x)
|
| 308 |
+
x = torch.flatten(x, 1)
|
| 309 |
+
logits = self.fc(x)
|
| 310 |
+
|
| 311 |
+
if self.training:
|
| 312 |
+
torch._assert(targets is not None, "targets should not be none when in training mode")
|
| 313 |
+
shared_head = ResNetConfig.gen_shared_head(self)
|
| 314 |
+
criterion = ResNetConfig.gen_criterion(label_smoothing, top_k)
|
| 315 |
+
loss = trp_criterion(self.trp_blocks, shared_head, criterion, lambdas, x, logits, targets, loss_normalization)
|
| 316 |
+
return logits, loss
|
| 317 |
+
return logits
|
| 318 |
+
return func
|
| 319 |
+
|
| 320 |
+
|
| 321 |
+
# EfficientNet for Image Classification
|
| 322 |
+
class EfficientNetConfig(MobileNetV2Config):
|
| 323 |
+
@staticmethod
|
| 324 |
+
def gen_shared_head(self):
|
| 325 |
+
def func(x):
|
| 326 |
+
"""
|
| 327 |
+
Args:
|
| 328 |
+
x (Tensor): Hidden States tensor of shape [B, hidden_units].
|
| 329 |
+
|
| 330 |
+
Returns:
|
| 331 |
+
logits (Tensor): Logits tensor of shape [B, C].
|
| 332 |
+
"""
|
| 333 |
+
logits = self.classifier(x)
|
| 334 |
+
return logits
|
| 335 |
+
return func
|
| 336 |
+
|
| 337 |
+
@staticmethod
|
| 338 |
+
def gen_forward(lambdas, loss_normalization=True, label_smoothing=0.0, top_k=1):
|
| 339 |
+
def func(self, images: Tensor, targets=None):
|
| 340 |
+
x = self.features(images)
|
| 341 |
+
x = self.avgpool(x)
|
| 342 |
+
x = torch.flatten(x, 1)
|
| 343 |
+
logits = self.classifier(x)
|
| 344 |
+
|
| 345 |
+
if self.training:
|
| 346 |
+
torch._assert(targets is not None, "targets should not be none when in training mode")
|
| 347 |
+
shared_head = EfficientNetConfig.gen_shared_head(self)
|
| 348 |
+
criterion = EfficientNetConfig.gen_criterion(label_smoothing, top_k)
|
| 349 |
+
loss = trp_criterion(self.trp_blocks, shared_head, criterion, lambdas, x, logits, targets, loss_normalization)
|
| 350 |
+
return logits, loss
|
| 351 |
+
return logits
|
| 352 |
+
return func
|
| 353 |
+
|
| 354 |
+
|
| 355 |
+
# ViT for Image Classification
|
| 356 |
+
class VisionTransformerConfig(MobileNetV2Config):
|
| 357 |
+
@staticmethod
|
| 358 |
+
def gen_shared_head(self):
|
| 359 |
+
def func(x):
|
| 360 |
+
"""
|
| 361 |
+
Args:
|
| 362 |
+
x (Tensor): Hidden States tensor of shape [B, hidden_units].
|
| 363 |
+
|
| 364 |
+
Returns:
|
| 365 |
+
logits (Tensor): Logits tensor of shape [B, C].
|
| 366 |
+
"""
|
| 367 |
+
logits = self.heads(x)
|
| 368 |
+
return logits
|
| 369 |
+
return func
|
| 370 |
+
|
| 371 |
+
@staticmethod
|
| 372 |
+
def gen_forward(lambdas, loss_normalization=True, label_smoothing=0.0, top_k=1):
|
| 373 |
+
def func(self, images: Tensor, targets=None):
|
| 374 |
+
x = self._process_input(images)
|
| 375 |
+
n = x.shape[0]
|
| 376 |
+
batch_class_token = self.class_token.expand(n, -1, -1)
|
| 377 |
+
x = torch.cat([batch_class_token, x], dim=1)
|
| 378 |
+
x = self.encoder(x)
|
| 379 |
+
x = x[:, 0]
|
| 380 |
+
|
| 381 |
+
logits = self.heads(x)
|
| 382 |
+
|
| 383 |
+
if self.training:
|
| 384 |
+
torch._assert(targets is not None, "targets should not be none when in training mode")
|
| 385 |
+
shared_head = VisionTransformerConfig.gen_shared_head(self)
|
| 386 |
+
criterion = VisionTransformerConfig.gen_criterion(label_smoothing, top_k)
|
| 387 |
+
loss = trp_criterion(self.trp_blocks, shared_head, criterion, lambdas, x, logits, targets, loss_normalization)
|
| 388 |
+
return logits, loss
|
| 389 |
+
return logits
|
| 390 |
+
return func
|
| 391 |
+
|
| 392 |
+
|
| 393 |
+
# Bert for Question Answering
|
| 394 |
+
class BertForQuestionAnsweringConfig(Config):
|
| 395 |
+
@staticmethod
|
| 396 |
+
def gen_criterion(top_k=1):
|
| 397 |
+
def func(input, target: List[Tensor], mask=None):
|
| 398 |
+
"""
|
| 399 |
+
Args:
|
| 400 |
+
input (Tensor): Input tensor of shape [B, C, 2].
|
| 401 |
+
target (List[Tensor]):
|
| 402 |
+
Start Positions of shape [B].
|
| 403 |
+
End Positions of shape [B].
|
| 404 |
+
|
| 405 |
+
Returns:
|
| 406 |
+
loss (Tensor): Scalar tensor representing the loss.
|
| 407 |
+
mask (Tensor): Boolean mask tensor of shape [B].
|
| 408 |
+
"""
|
| 409 |
+
start_positions, end_positions = target
|
| 410 |
+
|
| 411 |
+
if mask is None:
|
| 412 |
+
mask = torch.ones_like(start_positions, dtype=torch.float32, device=start_positions.device)
|
| 413 |
+
|
| 414 |
+
start_logits, end_logits = input.split(1, dim=-1)
|
| 415 |
+
start_logits = start_logits.squeeze(-1).contiguous()
|
| 416 |
+
end_logits = end_logits.squeeze(-1).contiguous()
|
| 417 |
+
|
| 418 |
+
# If we are on multi-GPU, split adds an extra dimension
|
| 419 |
+
if len(start_positions.size()) > 1:
|
| 420 |
+
start_positions = start_positions.squeeze(-1)
|
| 421 |
+
if len(end_positions.size()) > 1:
|
| 422 |
+
end_positions = end_positions.squeeze(-1)
|
| 423 |
+
# sometimes the start/end positions are outside our model inputs; we ignore these terms
|
| 424 |
+
ignored_index = start_logits.size(1)
|
| 425 |
+
start_positions = start_positions.clamp(0, ignored_index)
|
| 426 |
+
end_positions = end_positions.clamp(0, ignored_index)
|
| 427 |
+
|
| 428 |
+
masked_start_losses = F.cross_entropy(start_logits, start_positions, ignore_index=ignored_index, reduction="none")
|
| 429 |
+
start_loss = torch.sum(mask * masked_start_losses) / (torch.sum(mask) + 1e-6)
|
| 430 |
+
masked_end_losses = F.cross_entropy(end_logits, end_positions, ignore_index=ignored_index, reduction="none")
|
| 431 |
+
end_loss = torch.sum(mask * masked_end_losses) / (torch.sum(mask) + 1e-6)
|
| 432 |
+
|
| 433 |
+
with torch.no_grad():
|
| 434 |
+
topk_values, topk_indices = torch.topk(start_logits, top_k, dim=1)
|
| 435 |
+
mask = mask * torch.eq(topk_indices, start_positions[:, None]).any(dim=1).to(start_logits.dtype)
|
| 436 |
+
topk_values, topk_indices = torch.topk(end_logits, top_k, dim=1)
|
| 437 |
+
mask = mask * torch.eq(topk_indices, end_positions[:, None]).any(dim=1).to(end_logits.dtype)
|
| 438 |
+
|
| 439 |
+
return (start_loss + end_loss) / 2, mask
|
| 440 |
+
return func
|
| 441 |
+
|
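
The QA criterion averages start- and end-position cross-entropy and only keeps a sample in the mask when both positions land in the top-k predictions. A toy call with made-up shapes:

import torch
from trplib import BertForQuestionAnsweringConfig  # this file

B, L = 4, 64                              # illustrative batch and sequence length
qa_logits = torch.randn(B, L, 2)          # [B, C, 2] as expected by the criterion
starts = torch.randint(0, L, (B,))
ends = torch.randint(0, L, (B,))
criterion = BertForQuestionAnsweringConfig.gen_criterion(top_k=5)
loss, mask = criterion(qa_logits, [starts, ends])
print(loss.item(), mask)                  # scalar loss, per-sample 0/1 mask
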
| 442 |
+
@staticmethod
|
| 443 |
+
def gen_shared_head(self):
|
| 444 |
+
def func(hidden_states):
|
| 445 |
+
"""
|
| 446 |
+
Args:
|
| 447 |
+
hidden_states (Tensor): Hidden States of shape [B, C, hidden_units].
|
| 448 |
+
|
| 449 |
+
Returns:
|
| 450 |
+
logits (Tensor): Logits tensor of shape [B, C, 2].
|
| 451 |
+
"""
|
| 452 |
+
logits = self.qa_outputs(hidden_states)
|
| 453 |
+
return logits
|
| 454 |
+
return func
|
| 455 |
+
|
| 456 |
+
@staticmethod
|
| 457 |
+
def gen_forward(lambdas, loss_normalization=True, top_k=1):
|
| 458 |
+
def func(
|
| 459 |
+
self,
|
| 460 |
+
input_ids: Optional[torch.Tensor] = None,
|
| 461 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 462 |
+
token_type_ids: Optional[torch.Tensor] = None,
|
| 463 |
+
position_ids: Optional[torch.Tensor] = None,
|
| 464 |
+
head_mask: Optional[torch.Tensor] = None,
|
| 465 |
+
inputs_embeds: Optional[torch.Tensor] = None,
|
| 466 |
+
start_positions: Optional[torch.Tensor] = None,
|
| 467 |
+
end_positions: Optional[torch.Tensor] = None,
|
| 468 |
+
output_attentions: Optional[bool] = None,
|
| 469 |
+
output_hidden_states: Optional[bool] = None,
|
| 470 |
+
return_dict: Optional[bool] = None,
|
| 471 |
+
) -> Union[Tuple[torch.Tensor], QuestionAnsweringModelOutput]:
|
| 472 |
+
r"""
|
| 473 |
+
start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
| 474 |
+
Labels for position (index) of the start of the labelled span for computing the token classification loss.
|
| 475 |
+
Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
|
| 476 |
+
are not taken into account for computing the loss.
|
| 477 |
+
end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
| 478 |
+
Labels for position (index) of the end of the labelled span for computing the token classification loss.
|
| 479 |
+
Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
|
| 480 |
+
are not taken into account for computing the loss.
|
| 481 |
+
"""
|
| 482 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 483 |
+
|
| 484 |
+
outputs = self.bert(
|
| 485 |
+
input_ids,
|
| 486 |
+
attention_mask=attention_mask,
|
| 487 |
+
token_type_ids=token_type_ids,
|
| 488 |
+
position_ids=position_ids,
|
| 489 |
+
head_mask=head_mask,
|
| 490 |
+
inputs_embeds=inputs_embeds,
|
| 491 |
+
output_attentions=output_attentions,
|
| 492 |
+
output_hidden_states=output_hidden_states,
|
| 493 |
+
return_dict=return_dict,
|
| 494 |
+
)
|
| 495 |
+
|
| 496 |
+
sequence_output = outputs[0]
|
| 497 |
+
|
| 498 |
+
logits = self.qa_outputs(sequence_output)
|
| 499 |
+
start_logits, end_logits = logits.split(1, dim=-1)
|
| 500 |
+
start_logits = start_logits.squeeze(-1).contiguous()
|
| 501 |
+
end_logits = end_logits.squeeze(-1).contiguous()
|
| 502 |
+
|
| 503 |
+
total_loss = None
|
| 504 |
+
if start_positions is not None and end_positions is not None:
|
| 505 |
+
shared_head = BertForQuestionAnsweringConfig.gen_shared_head(self)
|
| 506 |
+
criterion = BertForQuestionAnsweringConfig.gen_criterion()
|
| 507 |
+
total_loss = trp_criterion(self.trp_blocks, shared_head, criterion, lambdas, sequence_output, logits, [start_positions, end_positions], loss_normalization) # NOTE: Apply TRP!
|
| 508 |
+
|
| 509 |
+
if not return_dict:
|
| 510 |
+
output = (start_logits, end_logits) + outputs[2:]
|
| 511 |
+
return ((total_loss,) + output) if total_loss is not None else output
|
| 512 |
+
|
| 513 |
+
return QuestionAnsweringModelOutput(
|
| 514 |
+
loss=total_loss,
|
| 515 |
+
start_logits=start_logits,
|
| 516 |
+
end_logits=end_logits,
|
| 517 |
+
hidden_states=outputs.hidden_states,
|
| 518 |
+
attentions=outputs.attentions,
|
| 519 |
+
)
|
| 520 |
+
return func
|
| 521 |
+
|
| 522 |
+
|
| 523 |
+
# FCN for Semantic Segmentation
|
| 524 |
+
class FCNConfig(Config):
|
| 525 |
+
@staticmethod
|
| 526 |
+
def gen_criterion(top_k=1):
|
| 527 |
+
def func(input, target, mask=None):
|
| 528 |
+
"""
|
| 529 |
+
Args:
|
| 530 |
+
input (Tensor): Input tensor of shape [B, C, H, W].
|
| 531 |
+
target (Tensor): Target labels of shape [B, H, W].
|
| 532 |
+
|
| 533 |
+
Returns:
|
| 534 |
+
loss (Tensor): Scalar tensor representing the loss.
|
| 535 |
+
mask (Tensor): Boolean mask tensor of shape [B, H, W].
|
| 536 |
+
"""
|
| 537 |
+
if mask is None:
|
| 538 |
+
mask = torch.ones_like(target, dtype=torch.float32, device=target.device)
|
| 539 |
+
|
| 540 |
+
masked_loss = F.cross_entropy(input, target, ignore_index=255, reduction="none")
|
| 541 |
+
loss = torch.sum(mask * masked_loss) / (torch.sum(mask) + 1e-6)
|
| 542 |
+
|
| 543 |
+
with torch.no_grad():
|
| 544 |
+
topk_values, topk_indices = torch.topk(input, top_k, dim=1)
|
| 545 |
+
mask = mask * torch.eq(topk_indices, target[:, None, :, :]).any(dim=1).to(input.dtype)
|
| 546 |
+
# mask = mask * torch.eq(torch.argmax(x, dim=1), target).to(x.dtype)
|
| 547 |
+
|
| 548 |
+
return loss, mask
|
| 549 |
+
return func
|
| 550 |
+
|
| 551 |
+
@staticmethod
|
| 552 |
+
def gen_out_shared_head(self, input_shape):
|
| 553 |
+
def func(features):
|
| 554 |
+
"""
|
| 555 |
+
Args:
|
| 556 |
+
features (Tensor): features tensor of shape [B, hidden_units, H, W].
|
| 557 |
+
|
| 558 |
+
Returns:
|
| 559 |
+
result (Tensor): Result tensor of shape [B, C, H, W].
|
| 560 |
+
"""
|
| 561 |
+
x = self.classifier(features)
|
| 562 |
+
result = F.interpolate(x, size=input_shape, mode="bilinear", align_corners=False)
|
| 563 |
+
return result
|
| 564 |
+
return func
|
| 565 |
+
|
| 566 |
+
@staticmethod
|
| 567 |
+
def gen_aux_shared_head(self, input_shape):
|
| 568 |
+
def func(features):
|
| 569 |
+
"""
|
| 570 |
+
Args:
|
| 571 |
+
features (Tensor): features tensor of shape [B, hidden_units, H, W].
|
| 572 |
+
|
| 573 |
+
Returns:
|
| 574 |
+
result (Tensor): Result tensor of shape [B, C, H, W].
|
| 575 |
+
"""
|
| 576 |
+
x = self.aux_classifier(features)
|
| 577 |
+
result = F.interpolate(x, size=input_shape, mode="bilinear", align_corners=False)
|
| 578 |
+
return result
|
| 579 |
+
return func
|
| 580 |
+
|
| 581 |
+
@staticmethod
|
| 582 |
+
def gen_forward(lambdas, loss_normalization=True, top_k=1):
|
| 583 |
+
def func(self, images: Tensor, targets=None):
|
| 584 |
+
input_shape = images.shape[-2:]
|
| 585 |
+
# contract: features is a dict of tensors
|
| 586 |
+
features = self.backbone(images)
|
| 587 |
+
|
| 588 |
+
result = OrderedDict()
|
| 589 |
+
x = features["out"]
|
| 590 |
+
x = self.classifier(x)
|
| 591 |
+
x = F.interpolate(x, size=input_shape, mode="bilinear", align_corners=False)
|
| 592 |
+
result["out"] = x
|
| 593 |
+
|
| 594 |
+
if self.aux_classifier is not None:
|
| 595 |
+
x = features["aux"]
|
| 596 |
+
x = self.aux_classifier(x)
|
| 597 |
+
x = F.interpolate(x, size=input_shape, mode="bilinear", align_corners=False)
|
| 598 |
+
result["aux"] = x
|
| 599 |
+
|
| 600 |
+
if self.training:
|
| 601 |
+
torch._assert(targets is not None, "targets should not be none when in training mode")
|
| 602 |
+
out_shared_head = FCNConfig.gen_out_shared_head(self, input_shape)
|
| 603 |
+
aux_shared_head = FCNConfig.gen_aux_shared_head(self, input_shape)
|
| 604 |
+
criterion = FCNConfig.gen_criterion(top_k)
|
| 605 |
+
out_loss = trp_criterion(self.out_trp_blocks, out_shared_head, criterion, lambdas, features["out"], result["out"], targets, loss_normalization)
|
| 606 |
+
aux_loss = trp_criterion(self.aux_trp_blocks, aux_shared_head, criterion, lambdas, features["aux"], result["aux"], targets, loss_normalization)
|
| 607 |
+
loss = out_loss + 0.5 * aux_loss
|
| 608 |
+
return result, loss
|
| 609 |
+
return result
|
| 610 |
+
return func
|
| 611 |
+
|
| 612 |
+
|
| 613 |
+
# DeepLabV3Config for Semantic Segmentation
|
| 614 |
+
class DeepLabV3Config(FCNConfig):
|
| 615 |
+
pass
|
| 616 |
+
|
| 617 |
+
|
| 618 |
+
# Bert for Text Classification
|
| 619 |
+
class BertForSequenceClassificationConfig(Config):
|
| 620 |
+
@staticmethod
|
| 621 |
+
def gen_criterion():
|
| 622 |
+
def func(input, target, mask=None):
|
| 623 |
+
"""
|
| 624 |
+
Args:
|
| 625 |
+
input (Tensor): Input tensor of shape [B, C].
|
| 626 |
+
target (Tensor): Target labels of shape [B].
|
| 627 |
+
|
| 628 |
+
Returns:
|
| 629 |
+
loss (Tensor): Scalar tensor representing the loss.
|
| 630 |
+
mask (Tensor): Boolean mask tensor of shape [B].
|
| 631 |
+
"""
|
| 632 |
+
if mask is None:
|
| 633 |
+
mask = torch.ones_like(target, dtype=torch.float32, device=target.device)
|
| 634 |
+
|
| 635 |
+
unmasked_loss = F.cross_entropy(input, target, reduction="none")
|
| 636 |
+
loss = torch.sum(mask * unmasked_loss) / (torch.sum(mask) + 1e-6)
|
| 637 |
+
|
| 638 |
+
with torch.no_grad():
|
| 639 |
+
mask = mask * torch.eq(torch.argmax(input, dim=1), target).to(input.dtype)
|
| 640 |
+
|
| 641 |
+
return loss, mask
|
| 642 |
+
return func
|
| 643 |
+
|
| 644 |
+
@staticmethod
|
| 645 |
+
def gen_shared_head(self):
|
| 646 |
+
def func(hidden_states):
|
| 647 |
+
"""
|
| 648 |
+
Args:
|
| 649 |
+
hidden_states (Tensor): Hidden States of shape [B, hidden_units].
|
| 650 |
+
|
| 651 |
+
Returns:
|
| 652 |
+
logits (Tensor): Logits tensor of shape [B, C].
|
| 653 |
+
"""
|
| 654 |
+
logits = self.classifier(hidden_states)
|
| 655 |
+
return logits
|
| 656 |
+
return func
|
| 657 |
+
|
| 658 |
+
@staticmethod
|
| 659 |
+
def gen_forward(lambdas, loss_normalization=False):
|
| 660 |
+
def func(
|
| 661 |
+
self,
|
| 662 |
+
input_ids: Optional[torch.Tensor] = None,
|
| 663 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 664 |
+
token_type_ids: Optional[torch.Tensor] = None,
|
| 665 |
+
position_ids: Optional[torch.Tensor] = None,
|
| 666 |
+
head_mask: Optional[torch.Tensor] = None,
|
| 667 |
+
inputs_embeds: Optional[torch.Tensor] = None,
|
| 668 |
+
labels: Optional[torch.Tensor] = None,
|
| 669 |
+
output_attentions: Optional[bool] = None,
|
| 670 |
+
output_hidden_states: Optional[bool] = None,
|
| 671 |
+
return_dict: Optional[bool] = None,
|
| 672 |
+
) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
|
| 673 |
+
r"""
|
| 674 |
+
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
| 675 |
+
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
|
| 676 |
+
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
|
| 677 |
+
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
| 678 |
+
"""
|
| 679 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 680 |
+
|
| 681 |
+
outputs = self.bert(
|
| 682 |
+
input_ids,
|
| 683 |
+
attention_mask=attention_mask,
|
| 684 |
+
token_type_ids=token_type_ids,
|
| 685 |
+
position_ids=position_ids,
|
| 686 |
+
head_mask=head_mask,
|
| 687 |
+
inputs_embeds=inputs_embeds,
|
| 688 |
+
output_attentions=output_attentions,
|
| 689 |
+
output_hidden_states=output_hidden_states,
|
| 690 |
+
return_dict=return_dict,
|
| 691 |
+
)
|
| 692 |
+
|
| 693 |
+
pooled_output = outputs[1]
|
| 694 |
+
|
| 695 |
+
pooled_output = self.dropout(pooled_output)
|
| 696 |
+
logits = self.classifier(pooled_output)
|
| 697 |
+
|
| 698 |
+
loss = None
|
| 699 |
+
if labels is not None:
|
| 700 |
+
assert self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int) # TODO: remove this
|
| 701 |
+
if self.config.problem_type is None:
|
| 702 |
+
if self.num_labels == 1:
|
| 703 |
+
self.config.problem_type = "regression"
|
| 704 |
+
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
|
| 705 |
+
self.config.problem_type = "single_label_classification"
|
| 706 |
+
else:
|
| 707 |
+
self.config.problem_type = "multi_label_classification"
|
| 708 |
+
|
| 709 |
+
if self.config.problem_type == "regression":
|
| 710 |
+
if self.num_labels == 1:
|
| 711 |
+
loss = F.mse_loss(logits.squeeze(), labels.squeeze())
|
| 712 |
+
else:
|
| 713 |
+
loss = F.mse_loss(logits, labels)
|
| 714 |
+
elif self.config.problem_type == "single_label_classification":
|
| 715 |
+
shared_head = BertForSequenceClassificationConfig.gen_shared_head(self)
|
| 716 |
+
criterion = BertForSequenceClassificationConfig.gen_criterion()
|
| 717 |
+
loss = trp_criterion(self.trp_blocks, shared_head, criterion, lambdas, pooled_output, logits, labels, loss_normalization)
|
| 718 |
+
elif self.config.problem_type == "multi_label_classification":
|
| 719 |
+
loss = F.binary_cross_entropy_with_logits(logits, labels)
|
| 720 |
+
if not return_dict:
|
| 721 |
+
output = (logits,) + outputs[2:]
|
| 722 |
+
return ((loss,) + output) if loss is not None else output
|
| 723 |
+
|
| 724 |
+
return SequenceClassifierOutput(
|
| 725 |
+
loss=loss,
|
| 726 |
+
logits=logits,
|
| 727 |
+
hidden_states=outputs.hidden_states,
|
| 728 |
+
attentions=outputs.attentions,
|
| 729 |
+
)
|
| 730 |
+
return func
|
| 731 |
+
|
| 732 |
+
|
| 733 |
+
# Roberta for Text Classification
|
| 734 |
+
class RobertaForSequenceClassificationConfig(BertForSequenceClassificationConfig):
|
| 735 |
+
@staticmethod
|
| 736 |
+
def gen_shared_head(self):
|
| 737 |
+
def func(hidden_states):
|
| 738 |
+
"""
|
| 739 |
+
Args:
|
| 740 |
+
hidden_states (Tensor): Hidden States of shape [B, hidden_units].
|
| 741 |
+
|
| 742 |
+
Returns:
|
| 743 |
+
logits (Tensor): Logits tensor of shape [B, C].
|
| 744 |
+
"""
|
| 745 |
+
logits = self.classifier(hidden_states)
|
| 746 |
+
return logits
|
| 747 |
+
return func
|
| 748 |
+
|
| 749 |
+
@staticmethod
|
| 750 |
+
def gen_forward(lambdas, loss_normalization=False):
|
| 751 |
+
def func(
|
| 752 |
+
self,
|
| 753 |
+
input_ids: Optional[torch.LongTensor] = None,
|
| 754 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
| 755 |
+
token_type_ids: Optional[torch.LongTensor] = None,
|
| 756 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 757 |
+
head_mask: Optional[torch.FloatTensor] = None,
|
| 758 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 759 |
+
labels: Optional[torch.LongTensor] = None,
|
| 760 |
+
output_attentions: Optional[bool] = None,
|
| 761 |
+
output_hidden_states: Optional[bool] = None,
|
| 762 |
+
return_dict: Optional[bool] = None,
|
| 763 |
+
) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
|
| 764 |
+
r"""
|
| 765 |
+
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
| 766 |
+
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
|
| 767 |
+
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
|
| 768 |
+
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
| 769 |
+
"""
|
| 770 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 771 |
+
|
| 772 |
+
outputs = self.roberta(
|
| 773 |
+
input_ids,
|
| 774 |
+
attention_mask=attention_mask,
|
| 775 |
+
token_type_ids=token_type_ids,
|
| 776 |
+
position_ids=position_ids,
|
| 777 |
+
head_mask=head_mask,
|
| 778 |
+
inputs_embeds=inputs_embeds,
|
| 779 |
+
output_attentions=output_attentions,
|
| 780 |
+
output_hidden_states=output_hidden_states,
|
| 781 |
+
return_dict=return_dict,
|
| 782 |
+
)
|
| 783 |
+
sequence_output = outputs[0]
|
| 784 |
+
logits = self.classifier(sequence_output)
|
| 785 |
+
|
| 786 |
+
loss = None
|
| 787 |
+
if labels is not None:
|
| 788 |
+
assert self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int) # TODO: remove this
|
| 789 |
+
# move labels to correct device to enable model parallelism
|
| 790 |
+
labels = labels.to(logits.device)
|
| 791 |
+
if self.config.problem_type is None:
|
| 792 |
+
if self.num_labels == 1:
|
| 793 |
+
self.config.problem_type = "regression"
|
| 794 |
+
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
|
| 795 |
+
self.config.problem_type = "single_label_classification"
|
| 796 |
+
else:
|
| 797 |
+
self.config.problem_type = "multi_label_classification"
|
| 798 |
+
|
| 799 |
+
if self.config.problem_type == "regression":
|
| 800 |
+
if self.num_labels == 1:
|
| 801 |
+
loss = F.mse_loss(logits.squeeze(), labels.squeeze())
|
| 802 |
+
else:
|
| 803 |
+
loss = F.mse_loss(logits, labels)
|
| 804 |
+
elif self.config.problem_type == "single_label_classification":
|
| 805 |
+
shared_head = BertForSequenceClassificationConfig.gen_shared_head(self)
|
| 806 |
+
criterion = BertForSequenceClassificationConfig.gen_criterion()
|
| 807 |
+
loss = trp_criterion(self.trp_blocks, shared_head, criterion, lambdas, sequence_output, logits, labels, loss_normalization)
|
| 808 |
+
elif self.config.problem_type == "multi_label_classification":
|
| 809 |
+
loss = F.binary_cross_entropy_with_logits(logits, labels)
|
| 810 |
+
|
| 811 |
+
if not return_dict:
|
| 812 |
+
output = (logits,) + outputs[2:]
|
| 813 |
+
return ((loss,) + output) if loss is not None else output
|
| 814 |
+
|
| 815 |
+
return SequenceClassifierOutput(
|
| 816 |
+
loss=loss,
|
| 817 |
+
logits=logits,
|
| 818 |
+
hidden_states=outputs.hidden_states,
|
| 819 |
+
attentions=outputs.attentions,
|
| 820 |
+
)
|
| 821 |
+
return func
|
| 822 |
+
|
| 823 |
+
|
| 824 |
+
# Wav2Vec2 for Speech Recognition
|
| 825 |
+
class Wav2Vec2ForCTCConfig(Config):
|
| 826 |
+
_HIDDEN_STATES_START_POSITION = 2
|
| 827 |
+
|
| 828 |
+
@staticmethod
|
| 829 |
+
def greedy_decode_ctc(
|
| 830 |
+
log_probs: torch.Tensor,
|
| 831 |
+
input_lengths: torch.Tensor,
|
| 832 |
+
blank_token_id: int,
|
| 833 |
+
target_lengths: torch.Tensor
|
| 834 |
+
):
|
| 835 |
+
"""
|
| 836 |
+
Convert logits to flattened predictions that match the shape of flattened_targets.
|
| 837 |
+
|
| 838 |
+
Args:
|
| 839 |
+
log_probs: [B, L, V] - log-softmax output
|
| 840 |
+
input_lengths: [B] - actual length of each input
|
| 841 |
+
blank_token_id: int - index of blank token
|
| 842 |
+
target_lengths: [B] - used to determine how many predictions to keep per sample
|
| 843 |
+
|
| 844 |
+
Returns:
|
| 845 |
+
flattened_predictions: 1D tensor, same total length as sum(target_lengths)
|
| 846 |
+
"""
|
| 847 |
+
batch_size = log_probs.size(0)
|
| 848 |
+
decoded_all = []
|
| 849 |
+
|
| 850 |
+
predicted_ids = log_probs.argmax(dim=-1) # [B, L]
|
| 851 |
+
|
| 852 |
+
for i in range(batch_size):
|
| 853 |
+
pred = predicted_ids[i][:input_lengths[i]] # [Li]
|
| 854 |
+
prev = None
|
| 855 |
+
decoded = []
|
| 856 |
+
for token in pred:
|
| 857 |
+
token = token.item()
|
| 858 |
+
if token != blank_token_id and token != prev:
|
| 859 |
+
decoded.append(token)
|
| 860 |
+
prev = token
|
| 861 |
+
# Trim or pad to match target_lengths[i]
|
| 862 |
+
tgt_len = target_lengths[i].item()
|
| 863 |
+
if len(decoded) >= tgt_len:
|
| 864 |
+
decoded = decoded[:tgt_len]
|
| 865 |
+
else:
|
| 866 |
+
decoded = decoded + [blank_token_id] * (tgt_len - len(decoded)) # pad with blank
|
| 867 |
+
decoded_all.extend(decoded)
|
| 868 |
+
|
| 869 |
+
return torch.tensor(decoded_all, dtype=torch.long, device=log_probs.device) # shape: [sum(target_lengths)]
|
| 870 |
+
|
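
greedy_decode_ctc collapses repeated tokens and blanks, then trims or pads each sequence so its predictions line up one-to-one with the flattened CTC targets. A small illustrative call (all values made up):

import torch
from trplib import Wav2Vec2ForCTCConfig  # this file

log_probs = torch.randn(2, 50, 32).log_softmax(dim=-1)   # [B, L, V]
input_lengths = torch.tensor([50, 42])
target_lengths = torch.tensor([12, 9])
flat_preds = Wav2Vec2ForCTCConfig.greedy_decode_ctc(
    log_probs, input_lengths, blank_token_id=0, target_lengths=target_lengths)
assert flat_preds.shape[0] == int(target_lengths.sum())   # 21 predictions in total
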
| 871 |
+
@staticmethod
|
| 872 |
+
def gen_criterion(input_lengths: Tensor, pad_token_id: int, ctc_zero_infinity: bool):
|
| 873 |
+
def func(logits: Tensor, labels: Tensor, mask=None):
|
| 874 |
+
"""
|
| 875 |
+
Args:
|
| 876 |
+
logits (Tensor): Logits tensor of shape [B, L, V]; log-softmax is applied inside.
|
| 877 |
+
labels (Tensor): Padded target labels of shape [B, L']; negative entries are masked out.
|
| 878 |
+
|
| 879 |
+
Returns:
|
| 880 |
+
loss (Tensor): Scalar tensor representing the loss.
|
| 881 |
+
mask (Tensor): Boolean mask tensor of shape [B].
|
| 882 |
+
"""
|
| 883 |
+
if mask is None:
|
| 884 |
+
mask = torch.ones_like(input_lengths, dtype=torch.float32, device=input_lengths.device)
|
| 885 |
+
|
| 886 |
+
log_probs = nn.functional.log_softmax(logits, dim=-1, dtype=torch.float32).transpose(0, 1)
|
| 887 |
+
labels_mask = labels >= 0
|
| 888 |
+
target_lengths = labels_mask.sum(-1)
|
| 889 |
+
flattened_targets = labels.masked_select(labels_mask)
|
| 890 |
+
with torch.backends.cudnn.flags(enabled=False):
|
| 891 |
+
masked_losses = nn.functional.ctc_loss(log_probs, flattened_targets, input_lengths, target_lengths, blank=pad_token_id, reduction="none", zero_infinity=ctc_zero_infinity)
|
| 892 |
+
loss = torch.sum(mask * masked_losses) / (torch.sum(mask) + 1e-6)
|
| 893 |
+
|
| 894 |
+
with torch.no_grad():
|
| 895 |
+
thres = 0.5
|
| 896 |
+
flattened_predictions = Wav2Vec2ForCTCConfig.greedy_decode_ctc(
|
| 897 |
+
log_probs.transpose(0, 1), # [B, T, V]
|
| 898 |
+
input_lengths=input_lengths,
|
| 899 |
+
blank_token_id=pad_token_id,
|
| 900 |
+
target_lengths=target_lengths
|
| 901 |
+
)
|
| 902 |
+
token_wise_mask = torch.eq(flattened_predictions, flattened_targets).to(flattened_targets.dtype)
|
| 903 |
+
segment_ids = torch.arange(len(target_lengths), device=target_lengths.device).repeat_interleave(target_lengths)
|
| 904 |
+
sequence_wise_mask = torch.zeros(len(target_lengths), dtype=target_lengths.dtype, device=token_wise_mask.device).scatter_add(0, segment_ids, token_wise_mask)
|
| 905 |
+
mask = mask * torch.ge(sequence_wise_mask, thres * target_lengths).to(flattened_targets.dtype)
|
| 906 |
+
|
| 907 |
+
return loss, mask
|
| 908 |
+
return func
|
| 909 |
+
|
| 910 |
+
@staticmethod
|
| 911 |
+
def gen_shared_head(self):
|
| 912 |
+
def func(hidden_states):
|
| 913 |
+
"""
|
| 914 |
+
Args:
|
| 915 |
+
hidden_states (Tensor): Hidden States of shape [B, L, hidden_units].
|
| 916 |
+
|
| 917 |
+
Returns:
|
| 918 |
+
logits (Tensor): Logits tensor of shape [B, L, V].
|
| 919 |
+
"""
|
| 920 |
+
logits = self.lm_head(hidden_states)
|
| 921 |
+
# log_probs = nn.functional.log_softmax(logits, dim=-1, dtype=torch.float32).transpose(0, 1)
|
| 922 |
+
return logits
|
| 923 |
+
return func
|
| 924 |
+
|
| 925 |
+
@staticmethod
|
| 926 |
+
def gen_forward(lambdas, loss_normalization=False):
|
| 927 |
+
def func(
|
| 928 |
+
self,
|
| 929 |
+
input_values: Optional[torch.Tensor],
|
| 930 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 931 |
+
output_attentions: Optional[bool] = None,
|
| 932 |
+
output_hidden_states: Optional[bool] = None,
|
| 933 |
+
return_dict: Optional[bool] = None,
|
| 934 |
+
labels: Optional[torch.Tensor] = None,
|
| 935 |
+
) -> Union[Tuple, CausalLMOutput]:
|
| 936 |
+
r"""
|
| 937 |
+
labels (`torch.LongTensor` of shape `(batch_size, target_length)`, *optional*):
|
| 938 |
+
Labels for connectionist temporal classification. Note that `target_length` has to be smaller or equal to
|
| 939 |
+
the sequence length of the output logits. Indices are selected in `[-100, 0, ..., config.vocab_size - 1]`.
|
| 940 |
+
All labels set to `-100` are ignored (masked), the loss is only computed for labels in `[0, ...,
|
| 941 |
+
config.vocab_size - 1]`.
|
| 942 |
+
"""
|
| 943 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 944 |
+
|
| 945 |
+
if labels is not None and labels.max() >= self.config.vocab_size:
|
| 946 |
+
raise ValueError(f"Label values must be <= vocab_size: {self.config.vocab_size}")
|
| 947 |
+
|
| 948 |
+
outputs = self.wav2vec2(
|
| 949 |
+
input_values,
|
| 950 |
+
attention_mask=attention_mask,
|
| 951 |
+
output_attentions=output_attentions,
|
| 952 |
+
output_hidden_states=output_hidden_states,
|
| 953 |
+
return_dict=return_dict,
|
| 954 |
+
)
|
| 955 |
+
|
| 956 |
+
hidden_states = outputs[0]
|
| 957 |
+
hidden_states = self.dropout(hidden_states)
|
| 958 |
+
|
| 959 |
+
logits = self.lm_head(hidden_states)
|
| 960 |
+
|
| 961 |
+
loss = None
|
| 962 |
+
if labels is not None:
|
| 963 |
+
# retrieve loss input_lengths from attention_mask
|
| 964 |
+
attention_mask = (
|
| 965 |
+
attention_mask if attention_mask is not None else torch.ones_like(input_values, dtype=torch.long)
|
| 966 |
+
)
|
| 967 |
+
input_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long)
|
| 968 |
+
shared_head = Wav2Vec2ForCTCConfig.gen_shared_head(self)
|
| 969 |
+
criterion = Wav2Vec2ForCTCConfig.gen_criterion(input_lengths, self.config.pad_token_id, self.config.ctc_zero_infinity)
|
| 970 |
+
loss = trp_criterion(self.trp_blocks, shared_head, criterion, lambdas, hidden_states, logits, labels, loss_normalization) # NOTE: Apply TRP!
|
| 971 |
+
|
| 972 |
+
if not return_dict:
|
| 973 |
+
output = (logits,) + outputs[Wav2Vec2ForCTCConfig._HIDDEN_STATES_START_POSITION:]
|
| 974 |
+
return ((loss,) + output) if loss is not None else output
|
| 975 |
+
|
| 976 |
+
return CausalLMOutput(
|
| 977 |
+
loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions
|
| 978 |
+
)
|
| 979 |
+
return func
|
| 980 |
+
|
| 981 |
+
|
| 982 |
+
# MBart for Translation
|
| 983 |
+
class MBartForConditionalGenerationConfig(Config):
|
| 984 |
+
@staticmethod
|
| 985 |
+
def gen_criterion(vocab_size: int, top_k=1):
|
| 986 |
+
def func(logits, labels, mask=None):
|
| 987 |
+
"""
|
| 988 |
+
Args:
|
| 989 |
+
logits (Tensor): Logits tensor of shape [B, L, V].
|
| 990 |
+
labels (Tensor): Target labels of shape [B, L].
|
| 991 |
+
|
| 992 |
+
Returns:
|
| 993 |
+
loss (Tensor): Scalar tensor representing the loss.
|
| 994 |
+
mask (Tensor): Boolean mask tensor of shape [B].
|
| 995 |
+
"""
|
| 996 |
+
if mask is None:
|
| 997 |
+
mask = torch.ones_like(labels.view(-1), dtype=torch.float32, device=labels.device)
|
| 998 |
+
|
| 999 |
+
masked_losses = F.cross_entropy(logits.view(-1, vocab_size), labels.view(-1), reduction="none")
|
| 1000 |
+
loss = torch.sum(mask * masked_losses) / (torch.sum(mask) + 1e-6)
|
| 1001 |
+
|
| 1002 |
+
with torch.no_grad():
|
| 1003 |
+
topk_values, topk_indices = torch.topk(logits.view(-1, vocab_size), top_k, dim=1)
|
| 1004 |
+
mask = mask * torch.eq(topk_indices, labels.view(-1, 1)).any(dim=1).to(logits.dtype)
|
| 1005 |
+
|
| 1006 |
+
return loss, mask
|
| 1007 |
+
return func
|
| 1008 |
+
|
| 1009 |
+
@staticmethod
|
| 1010 |
+
def gen_shared_head(self):
|
| 1011 |
+
def func(hidden_states):
|
| 1012 |
+
"""
|
| 1013 |
+
Args:
|
| 1014 |
+
hidden_states (Tensor): Hidden States of shape [B, L, hidden_units].
|
| 1015 |
+
|
| 1016 |
+
Returns:
|
| 1017 |
+
logits (Tensor): Logits tensor of shape [B, L, V].
|
| 1018 |
+
"""
|
| 1019 |
+
logits = self.lm_head(hidden_states) + self.final_logits_bias
|
| 1020 |
+
return logits
|
| 1021 |
+
return func
|
| 1022 |
+
|
| 1023 |
+
@staticmethod
|
| 1024 |
+
def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int):
|
| 1025 |
+
"""
|
| 1026 |
+
Shift input ids one token to the right, and wrap the last non-pad token (the <LID> token). Note that MBart does not
|
| 1027 |
+
have a single `decoder_start_token_id` in contrast to other Bart-like models.
|
| 1028 |
+
"""
|
| 1029 |
+
prev_output_tokens = input_ids.clone()
|
| 1030 |
+
|
| 1031 |
+
if pad_token_id is None:
|
| 1032 |
+
raise ValueError("self.model.config.pad_token_id has to be defined.")
|
| 1033 |
+
# replace possible -100 values in labels by `pad_token_id`
|
| 1034 |
+
prev_output_tokens.masked_fill_(prev_output_tokens == -100, pad_token_id)
|
| 1035 |
+
|
| 1036 |
+
index_of_eos = (prev_output_tokens.ne(pad_token_id).sum(dim=1) - 1).unsqueeze(-1)
|
| 1037 |
+
decoder_start_tokens = prev_output_tokens.gather(1, index_of_eos).squeeze()
|
| 1038 |
+
prev_output_tokens[:, 1:] = prev_output_tokens[:, :-1].clone()
|
| 1039 |
+
prev_output_tokens[:, 0] = decoder_start_tokens
|
| 1040 |
+
|
| 1041 |
+
return prev_output_tokens
|
| 1042 |
+
|
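
shift_tokens_right replaces -100 with the pad id, then rotates the language-id token that closes each sequence to the front so it acts as the decoder start token. A toy trace under assumed ids (pad=1, <LID>=250004):

import torch
from trplib import MBartForConditionalGenerationConfig  # this file

labels = torch.tensor([[ 17,     27, 250004,  -100],
                       [  9, 250004,   -100,  -100]])
shifted = MBartForConditionalGenerationConfig.shift_tokens_right(labels, pad_token_id=1)
# tensor([[250004,     17,     27, 250004],
#         [250004,      9, 250004,      1]])
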
| 1043 |
+
@staticmethod
|
| 1044 |
+
def gen_forward(lambdas, loss_normalization=False):
|
| 1045 |
+
def func(
|
| 1046 |
+
self,
|
| 1047 |
+
input_ids: torch.LongTensor = None,
|
| 1048 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 1049 |
+
decoder_input_ids: Optional[torch.LongTensor] = None,
|
| 1050 |
+
decoder_attention_mask: Optional[torch.LongTensor] = None,
|
| 1051 |
+
head_mask: Optional[torch.Tensor] = None,
|
| 1052 |
+
decoder_head_mask: Optional[torch.Tensor] = None,
|
| 1053 |
+
cross_attn_head_mask: Optional[torch.Tensor] = None,
|
| 1054 |
+
encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
| 1055 |
+
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
| 1056 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 1057 |
+
decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 1058 |
+
labels: Optional[torch.LongTensor] = None,
|
| 1059 |
+
use_cache: Optional[bool] = None,
|
| 1060 |
+
output_attentions: Optional[bool] = None,
|
| 1061 |
+
output_hidden_states: Optional[bool] = None,
|
| 1062 |
+
return_dict: Optional[bool] = None,
|
| 1063 |
+
) -> Union[Seq2SeqLMOutput, Tuple[torch.FloatTensor]]:
|
| 1064 |
+
r"""
|
| 1065 |
+
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 1066 |
+
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
| 1067 |
+
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
| 1068 |
+
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
| 1069 |
+
|
| 1070 |
+
Returns:
|
| 1071 |
+
|
| 1072 |
+
"""
|
| 1073 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 1074 |
+
|
| 1075 |
+
if labels is not None:
|
| 1076 |
+
# if use_cache:
|
| 1077 |
+
# logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.")
|
| 1078 |
+
use_cache = False
|
| 1079 |
+
if decoder_input_ids is None and decoder_inputs_embeds is None:
|
| 1080 |
+
decoder_input_ids = MBartForConditionalGenerationConfig.shift_tokens_right(labels, self.config.pad_token_id)
|
| 1081 |
+
|
| 1082 |
+
outputs = self.model(
|
| 1083 |
+
input_ids,
|
| 1084 |
+
attention_mask=attention_mask,
|
| 1085 |
+
decoder_input_ids=decoder_input_ids,
|
| 1086 |
+
encoder_outputs=encoder_outputs,
|
| 1087 |
+
decoder_attention_mask=decoder_attention_mask,
|
| 1088 |
+
head_mask=head_mask,
|
| 1089 |
+
decoder_head_mask=decoder_head_mask,
|
| 1090 |
+
cross_attn_head_mask=cross_attn_head_mask,
|
| 1091 |
+
past_key_values=past_key_values,
|
| 1092 |
+
inputs_embeds=inputs_embeds,
|
| 1093 |
+
decoder_inputs_embeds=decoder_inputs_embeds,
|
| 1094 |
+
use_cache=use_cache,
|
| 1095 |
+
output_attentions=output_attentions,
|
| 1096 |
+
output_hidden_states=output_hidden_states,
|
| 1097 |
+
return_dict=return_dict,
|
| 1098 |
+
)
|
| 1099 |
+
lm_logits = self.lm_head(outputs[0]) + self.final_logits_bias
|
| 1100 |
+
|
| 1101 |
+
masked_lm_loss = None
|
| 1102 |
+
if labels is not None:
|
| 1103 |
+
shared_head = MBartForConditionalGenerationConfig.gen_shared_head(self)
|
| 1104 |
+
criterion = MBartForConditionalGenerationConfig.gen_criterion(self.config.vocab_size)
|
| 1105 |
+
masked_lm_loss = trp_criterion(self.trp_blocks, shared_head, criterion, lambdas, outputs[0], lm_logits, labels, loss_normalization)
|
| 1106 |
+
|
| 1107 |
+
if not return_dict:
|
| 1108 |
+
output = (lm_logits,) + outputs[1:]
|
| 1109 |
+
return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
|
| 1110 |
+
|
| 1111 |
+
return Seq2SeqLMOutput(
|
| 1112 |
+
loss=masked_lm_loss,
|
| 1113 |
+
logits=lm_logits,
|
| 1114 |
+
past_key_values=outputs.past_key_values,
|
| 1115 |
+
decoder_hidden_states=outputs.decoder_hidden_states,
|
| 1116 |
+
decoder_attentions=outputs.decoder_attentions,
|
| 1117 |
+
cross_attentions=outputs.cross_attentions,
|
| 1118 |
+
encoder_last_hidden_state=outputs.encoder_last_hidden_state,
|
| 1119 |
+
encoder_hidden_states=outputs.encoder_hidden_states,
|
| 1120 |
+
encoder_attentions=outputs.encoder_attentions,
|
| 1121 |
+
)
|
| 1122 |
+
return func
|
| 1123 |
+
|
| 1124 |
+
|
| 1125 |
+
def apply_trp(model, depths: int, p: float, lambdas: List[float], **kwargs):
|
| 1126 |
+
if isinstance(model, transformers.Wav2Vec2ForSequenceClassification):
|
| 1127 |
+
print("✅ Applying TRP to Wav2Vec2 for Audio Classification...")
|
| 1128 |
+
model.trp_blocks = torch.nn.ModuleList([TPBlock(depths, 768, p) for _ in lambdas])
|
| 1129 |
+
model.forward = types.MethodType(Wav2Vec2ForSequenceClassificationConfig.gen_forward(lambdas, False), model)
|
| 1130 |
+
elif isinstance(model, MobileNetV2):
|
| 1131 |
+
print("✅ Applying TRP to MobileNetV2 for Image Classification...")
|
| 1132 |
+
model.trp_blocks = torch.nn.ModuleList([TPBlock(depths, 1280, p) for _ in lambdas])
|
| 1133 |
+
model.forward = types.MethodType(MobileNetV2Config.gen_forward(lambdas, True, label_smoothing=kwargs["label_smoothing"], top_k=1), model)
|
| 1134 |
+
elif isinstance(model, ResNet):
|
| 1135 |
+
print("✅ Applying TRP to ResNet for Image Classification...")
|
| 1136 |
+
model.trp_blocks = torch.nn.ModuleList([TPBlock(depths, 2048, p) for _ in lambdas])
|
| 1137 |
+
model.forward = types.MethodType(ResNetConfig.gen_forward(lambdas, True, label_smoothing=kwargs["label_smoothing"], top_k=1), model)
|
| 1138 |
+
elif isinstance(model, EfficientNet):
|
| 1139 |
+
print("✅ Applying TRP to EfficientNet for Image Classification...")
|
| 1140 |
+
model.trp_blocks = torch.nn.ModuleList([TPBlock(depths, 1280, p) for _ in lambdas])
|
| 1141 |
+
model.forward = types.MethodType(EfficientNetConfig.gen_forward(lambdas, True, label_smoothing=kwargs["label_smoothing"], top_k=1), model)
|
| 1142 |
+
elif isinstance(model, VisionTransformer):
|
| 1143 |
+
print("✅ Applying TRP to VisionTransformer for Image Classification...")
|
| 1144 |
+
model.trp_blocks = torch.nn.ModuleList([TPBlock(depths, 768, p) for _ in lambdas])
|
| 1145 |
+
model.forward = types.MethodType(VisionTransformerConfig.gen_forward(lambdas, True, label_smoothing=kwargs["label_smoothing"], top_k=1), model)
|
| 1146 |
+
elif isinstance(model, transformers.BertForQuestionAnswering):
|
| 1147 |
+
print("✅ Applying TRP to Bert for Question Answering...")
|
| 1148 |
+
model.trp_blocks = torch.nn.ModuleList([TPBlock(depths, 768, p) for _ in lambdas])
|
| 1149 |
+
model.forward = types.MethodType(BertForQuestionAnsweringConfig.gen_forward(lambdas, True, 1), model)
|
| 1150 |
+
elif isinstance(model, FCN):
|
| 1151 |
+
print("✅ Applying TRP to FCN for Semantic Segmentation...")
|
| 1152 |
+
model.out_trp_blocks = torch.nn.ModuleList([TPBlock(depths, 2048, p, dim=1) for _ in lambdas])
|
| 1153 |
+
model.aux_trp_blocks = torch.nn.ModuleList([TPBlock(depths, 1024, p, dim=1) for _ in lambdas])
|
| 1154 |
+
model.forward = types.MethodType(FCNConfig.gen_forward(lambdas, True, 1), model)
|
| 1155 |
+
elif isinstance(model, DeepLabV3):
|
| 1156 |
+
print("✅ Applying TRP to DeepLabV3 for Semantic Segmentation...")
|
| 1157 |
+
model.out_trp_blocks = torch.nn.ModuleList([TPBlock(depths, 2048, p, dim=1) for _ in lambdas])
|
| 1158 |
+
model.aux_trp_blocks = torch.nn.ModuleList([TPBlock(depths, 1024, p, dim=1) for _ in lambdas])
|
| 1159 |
+
model.forward = types.MethodType(DeepLabV3Config.gen_forward(lambdas, True, 1), model)
|
| 1160 |
+
elif isinstance(model, transformers.BertForSequenceClassification):
|
| 1161 |
+
print("✅ Applying TRP to Bert for Text Classification...")
|
| 1162 |
+
model.trp_blocks = torch.nn.ModuleList([TPBlock(depths, 768, p) for _ in lambdas])
|
| 1163 |
+
model.forward = types.MethodType(BertForSequenceClassificationConfig.gen_forward(lambdas, False), model)
|
| 1164 |
+
elif isinstance(model, transformers.RobertaForSequenceClassification):
|
| 1165 |
+
print("✅ Applying TRP to Roberta for Text Classification...")
|
| 1166 |
+
model.trp_blocks = torch.nn.ModuleList([TPBlock(depths, 768, p) for _ in lambdas])
|
| 1167 |
+
model.forward = types.MethodType(RobertaForSequenceClassificationConfig.gen_forward(lambdas, False), model)
|
| 1168 |
+
elif isinstance(model, transformers.Wav2Vec2ForCTC):
|
| 1169 |
+
print("✅ Applying TRP to Wav2Vec2 for Speech Recognition...")
|
| 1170 |
+
model.trp_blocks = torch.nn.ModuleList([TPBlock(depths, 1024, p) for _ in lambdas])
|
| 1171 |
+
model.forward = types.MethodType(Wav2Vec2ForCTCConfig.gen_forward(lambdas, False), model)
|
| 1172 |
+
elif isinstance(model, transformers.MBartForConditionalGeneration):
|
| 1173 |
+
print("✅ Applying TRP to MBart for Translation...")
|
| 1174 |
+
model.trp_blocks = torch.nn.ModuleList([TPBlock(depths, 1024, p) for _ in lambdas])
|
| 1175 |
+
model.forward = types.MethodType(MBartForConditionalGenerationConfig.gen_forward(lambdas, False), model)
|
| 1176 |
+
else:
|
| 1177 |
+
torch._assert(
|
| 1178 |
+
isinstance(model, transformers.Wav2Vec2ForSequenceClassification),
|
| 1179 |
+
"The model should be an object of [`Wav2Vec2ForSequenceClassification`].")
|
| 1180 |
+
|
| 1181 |
+
return model
|
hpo-examples/image-classification/utils.py
ADDED
|
@@ -0,0 +1,465 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import copy
|
| 2 |
+
import datetime
|
| 3 |
+
import errno
|
| 4 |
+
import hashlib
|
| 5 |
+
import os
|
| 6 |
+
import time
|
| 7 |
+
from collections import defaultdict, deque, OrderedDict
|
| 8 |
+
from typing import List, Optional, Tuple
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
import torch.distributed as dist
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class SmoothedValue:
|
| 15 |
+
"""Track a series of values and provide access to smoothed values over a
|
| 16 |
+
window or the global series average.
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
def __init__(self, window_size=20, fmt=None):
|
| 20 |
+
if fmt is None:
|
| 21 |
+
fmt = "{median:.4f} ({global_avg:.4f})"
|
| 22 |
+
self.deque = deque(maxlen=window_size)
|
| 23 |
+
self.total = 0.0
|
| 24 |
+
self.count = 0
|
| 25 |
+
self.fmt = fmt
|
| 26 |
+
|
| 27 |
+
def update(self, value, n=1):
|
| 28 |
+
self.deque.append(value)
|
| 29 |
+
self.count += n
|
| 30 |
+
self.total += value * n
|
| 31 |
+
|
| 32 |
+
def synchronize_between_processes(self):
|
| 33 |
+
"""
|
| 34 |
+
Warning: does not synchronize the deque!
|
| 35 |
+
"""
|
| 36 |
+
t = reduce_across_processes([self.count, self.total])
|
| 37 |
+
t = t.tolist()
|
| 38 |
+
self.count = int(t[0])
|
| 39 |
+
self.total = t[1]
|
| 40 |
+
|
| 41 |
+
@property
|
| 42 |
+
def median(self):
|
| 43 |
+
d = torch.tensor(list(self.deque))
|
| 44 |
+
return d.median().item()
|
| 45 |
+
|
| 46 |
+
@property
|
| 47 |
+
def avg(self):
|
| 48 |
+
d = torch.tensor(list(self.deque), dtype=torch.float32)
|
| 49 |
+
return d.mean().item()
|
| 50 |
+
|
| 51 |
+
@property
|
| 52 |
+
def global_avg(self):
|
| 53 |
+
return self.total / self.count
|
| 54 |
+
|
| 55 |
+
@property
|
| 56 |
+
def max(self):
|
| 57 |
+
return max(self.deque)
|
| 58 |
+
|
| 59 |
+
@property
|
| 60 |
+
def value(self):
|
| 61 |
+
return self.deque[-1]
|
| 62 |
+
|
| 63 |
+
def __str__(self):
|
| 64 |
+
return self.fmt.format(
|
| 65 |
+
median=self.median, avg=self.avg, global_avg=self.global_avg, max=self.max, value=self.value
|
| 66 |
+
)
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
class MetricLogger:
|
| 70 |
+
def __init__(self, delimiter="\t"):
|
| 71 |
+
self.meters = defaultdict(SmoothedValue)
|
| 72 |
+
self.delimiter = delimiter
|
| 73 |
+
|
| 74 |
+
def update(self, **kwargs):
|
| 75 |
+
for k, v in kwargs.items():
|
| 76 |
+
if isinstance(v, torch.Tensor):
|
| 77 |
+
v = v.item()
|
| 78 |
+
assert isinstance(v, (float, int))
|
| 79 |
+
self.meters[k].update(v)
|
| 80 |
+
|
| 81 |
+
def __getattr__(self, attr):
|
| 82 |
+
if attr in self.meters:
|
| 83 |
+
return self.meters[attr]
|
| 84 |
+
if attr in self.__dict__:
|
| 85 |
+
return self.__dict__[attr]
|
| 86 |
+
raise AttributeError(f"'{type(self).__name__}' object has no attribute '{attr}'")
|
| 87 |
+
|
| 88 |
+
def __str__(self):
|
| 89 |
+
loss_str = []
|
| 90 |
+
for name, meter in self.meters.items():
|
| 91 |
+
loss_str.append(f"{name}: {str(meter)}")
|
| 92 |
+
return self.delimiter.join(loss_str)
|
| 93 |
+
|
| 94 |
+
def synchronize_between_processes(self):
|
| 95 |
+
for meter in self.meters.values():
|
| 96 |
+
meter.synchronize_between_processes()
|
| 97 |
+
|
| 98 |
+
def add_meter(self, name, meter):
|
| 99 |
+
self.meters[name] = meter
|
| 100 |
+
|
| 101 |
+
def log_every(self, iterable, print_freq, header=None):
|
| 102 |
+
i = 0
|
| 103 |
+
if not header:
|
| 104 |
+
header = ""
|
| 105 |
+
start_time = time.time()
|
| 106 |
+
end = time.time()
|
| 107 |
+
iter_time = SmoothedValue(fmt="{avg:.4f}")
|
| 108 |
+
data_time = SmoothedValue(fmt="{avg:.4f}")
|
| 109 |
+
space_fmt = ":" + str(len(str(len(iterable)))) + "d"
|
| 110 |
+
if torch.cuda.is_available():
|
| 111 |
+
log_msg = self.delimiter.join(
|
| 112 |
+
[
|
| 113 |
+
header,
|
| 114 |
+
"[{0" + space_fmt + "}/{1}]",
|
| 115 |
+
"eta: {eta}",
|
| 116 |
+
"{meters}",
|
| 117 |
+
"time: {time}",
|
| 118 |
+
"data: {data}",
|
| 119 |
+
"max mem: {memory:.0f}",
|
| 120 |
+
]
|
| 121 |
+
)
|
| 122 |
+
else:
|
| 123 |
+
log_msg = self.delimiter.join(
|
| 124 |
+
[header, "[{0" + space_fmt + "}/{1}]", "eta: {eta}", "{meters}", "time: {time}", "data: {data}"]
|
| 125 |
+
)
|
| 126 |
+
MB = 1024.0 * 1024.0
|
| 127 |
+
for obj in iterable:
|
| 128 |
+
data_time.update(time.time() - end)
|
| 129 |
+
yield obj
|
| 130 |
+
iter_time.update(time.time() - end)
|
| 131 |
+
if i % print_freq == 0:
|
| 132 |
+
eta_seconds = iter_time.global_avg * (len(iterable) - i)
|
| 133 |
+
eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
|
| 134 |
+
if torch.cuda.is_available():
|
| 135 |
+
print(
|
| 136 |
+
log_msg.format(
|
| 137 |
+
i,
|
| 138 |
+
len(iterable),
|
| 139 |
+
eta=eta_string,
|
| 140 |
+
meters=str(self),
|
| 141 |
+
time=str(iter_time),
|
| 142 |
+
data=str(data_time),
|
| 143 |
+
memory=torch.cuda.max_memory_allocated() / MB,
|
| 144 |
+
)
|
| 145 |
+
)
|
| 146 |
+
else:
|
| 147 |
+
print(
|
| 148 |
+
log_msg.format(
|
| 149 |
+
i, len(iterable), eta=eta_string, meters=str(self), time=str(iter_time), data=str(data_time)
|
| 150 |
+
)
|
| 151 |
+
)
|
| 152 |
+
i += 1
|
| 153 |
+
end = time.time()
|
| 154 |
+
total_time = time.time() - start_time
|
| 155 |
+
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
|
| 156 |
+
print(f"{header} Total time: {total_time_str}")
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
class ExponentialMovingAverage(torch.optim.swa_utils.AveragedModel):
|
| 160 |
+
"""Maintains moving averages of model parameters using an exponential decay.
|
| 161 |
+
``ema_avg = decay * avg_model_param + (1 - decay) * model_param``
|
| 162 |
+
`torch.optim.swa_utils.AveragedModel <https://pytorch.org/docs/stable/optim.html#custom-averaging-strategies>`_
|
| 163 |
+
is used to compute the EMA.
|
| 164 |
+
"""
|
| 165 |
+
|
| 166 |
+
def __init__(self, model, decay, device="cpu"):
|
| 167 |
+
def ema_avg(avg_model_param, model_param, num_averaged):
|
| 168 |
+
return decay * avg_model_param + (1 - decay) * model_param
|
| 169 |
+
|
| 170 |
+
super().__init__(model, device, ema_avg, use_buffers=True)
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
def accuracy(output, target, topk=(1,)):
|
| 174 |
+
"""Computes the accuracy over the k top predictions for the specified values of k"""
|
| 175 |
+
with torch.inference_mode():
|
| 176 |
+
maxk = max(topk)
|
| 177 |
+
batch_size = target.size(0)
|
| 178 |
+
if target.ndim == 2:
|
| 179 |
+
target = target.max(dim=1)[1]
|
| 180 |
+
|
| 181 |
+
_, pred = output.topk(maxk, 1, True, True)
|
| 182 |
+
pred = pred.t()
|
| 183 |
+
correct = pred.eq(target[None])
|
| 184 |
+
|
| 185 |
+
res = []
|
| 186 |
+
for k in topk:
|
| 187 |
+
correct_k = correct[:k].flatten().sum(dtype=torch.float32)
|
| 188 |
+
res.append(correct_k * (100.0 / batch_size))
|
| 189 |
+
return res
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
def mkdir(path):
|
| 193 |
+
try:
|
| 194 |
+
os.makedirs(path)
|
| 195 |
+
except OSError as e:
|
| 196 |
+
if e.errno != errno.EEXIST:
|
| 197 |
+
raise
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
def setup_for_distributed(is_master):
|
| 201 |
+
"""
|
| 202 |
+
This function disables printing when not in master process
|
| 203 |
+
"""
|
| 204 |
+
import builtins as __builtin__
|
| 205 |
+
|
| 206 |
+
builtin_print = __builtin__.print
|
| 207 |
+
|
| 208 |
+
def print(*args, **kwargs):
|
| 209 |
+
force = kwargs.pop("force", False)
|
| 210 |
+
if is_master or force:
|
| 211 |
+
builtin_print(*args, **kwargs)
|
| 212 |
+
|
| 213 |
+
__builtin__.print = print
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
def is_dist_avail_and_initialized():
|
| 217 |
+
if not dist.is_available():
|
| 218 |
+
return False
|
| 219 |
+
if not dist.is_initialized():
|
| 220 |
+
return False
|
| 221 |
+
return True
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
def get_world_size():
|
| 225 |
+
if not is_dist_avail_and_initialized():
|
| 226 |
+
return 1
|
| 227 |
+
return dist.get_world_size()
|
| 228 |
+
|
| 229 |
+
|
| 230 |
+
def get_rank():
|
| 231 |
+
if not is_dist_avail_and_initialized():
|
| 232 |
+
return 0
|
| 233 |
+
return dist.get_rank()
|
| 234 |
+
|
| 235 |
+
|
| 236 |
+
def is_main_process():
|
| 237 |
+
return get_rank() == 0
|
| 238 |
+
|
| 239 |
+
|
| 240 |
+
def save_on_master(*args, **kwargs):
|
| 241 |
+
if is_main_process():
|
| 242 |
+
torch.save(*args, **kwargs)
|
| 243 |
+
|
| 244 |
+
|
| 245 |
+
def init_distributed_mode(args):
|
| 246 |
+
if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
|
| 247 |
+
args.rank = int(os.environ["RANK"])
|
| 248 |
+
args.world_size = int(os.environ["WORLD_SIZE"])
|
| 249 |
+
args.gpu = int(os.environ["LOCAL_RANK"])
|
| 250 |
+
elif "SLURM_PROCID" in os.environ:
|
| 251 |
+
args.rank = int(os.environ["SLURM_PROCID"])
|
| 252 |
+
args.gpu = args.rank % torch.cuda.device_count()
|
| 253 |
+
elif hasattr(args, "rank"):
|
| 254 |
+
pass
|
| 255 |
+
else:
|
| 256 |
+
print("Not using distributed mode")
|
| 257 |
+
args.distributed = False
|
| 258 |
+
return
|
| 259 |
+
|
| 260 |
+
args.distributed = True
|
| 261 |
+
|
| 262 |
+
torch.cuda.set_device(args.gpu)
|
| 263 |
+
args.dist_backend = "nccl"
|
| 264 |
+
print(f"| distributed init (rank {args.rank}): {args.dist_url}", flush=True)
|
| 265 |
+
torch.distributed.init_process_group(
|
| 266 |
+
backend=args.dist_backend, init_method=args.dist_url, world_size=args.world_size, rank=args.rank
|
| 267 |
+
)
|
| 268 |
+
torch.distributed.barrier()
|
| 269 |
+
setup_for_distributed(args.rank == 0)
|
| 270 |
+
|
| 271 |
+
|
| 272 |
+
def average_checkpoints(inputs):
|
| 273 |
+
"""Loads checkpoints from inputs and returns a model with averaged weights. Original implementation taken from:
|
| 274 |
+
https://github.com/pytorch/fairseq/blob/a48f235636557b8d3bc4922a6fa90f3a0fa57955/scripts/average_checkpoints.py#L16
|
| 275 |
+
|
| 276 |
+
Args:
|
| 277 |
+
inputs (List[str]): An iterable of string paths of checkpoints to load from.
|
| 278 |
+
Returns:
|
| 279 |
+
A dict of string keys mapping to various values. The 'model' key
|
| 280 |
+
from the returned dict should correspond to an OrderedDict mapping
|
| 281 |
+
string parameter names to torch Tensors.
|
| 282 |
+
"""
|
| 283 |
+
params_dict = OrderedDict()
|
| 284 |
+
params_keys = None
|
| 285 |
+
new_state = None
|
| 286 |
+
num_models = len(inputs)
|
| 287 |
+
for fpath in inputs:
|
| 288 |
+
with open(fpath, "rb") as f:
|
| 289 |
+
state = torch.load(
|
| 290 |
+
f,
|
| 291 |
+
map_location=(lambda s, _: torch.serialization.default_restore_location(s, "cpu")),
|
| 292 |
+
)
|
| 293 |
+
# Copies over the settings from the first checkpoint
|
| 294 |
+
if new_state is None:
|
| 295 |
+
new_state = state
|
| 296 |
+
model_params = state["model"]
|
| 297 |
+
model_params_keys = list(model_params.keys())
|
| 298 |
+
if params_keys is None:
|
| 299 |
+
params_keys = model_params_keys
|
| 300 |
+
elif params_keys != model_params_keys:
|
| 301 |
+
raise KeyError(
|
| 302 |
+
f"For checkpoint {f}, expected list of params: {params_keys}, but found: {model_params_keys}"
|
| 303 |
+
)
|
| 304 |
+
for k in params_keys:
|
| 305 |
+
p = model_params[k]
|
| 306 |
+
if isinstance(p, torch.HalfTensor):
|
| 307 |
+
p = p.float()
|
| 308 |
+
if k not in params_dict:
|
| 309 |
+
params_dict[k] = p.clone()
|
| 310 |
+
# NOTE: clone() is needed in case of p is a shared parameter
|
| 311 |
+
else:
|
| 312 |
+
params_dict[k] += p
|
| 313 |
+
averaged_params = OrderedDict()
|
| 314 |
+
for k, v in params_dict.items():
|
| 315 |
+
averaged_params[k] = v
|
| 316 |
+
if averaged_params[k].is_floating_point():
|
| 317 |
+
averaged_params[k].div_(num_models)
|
| 318 |
+
else:
|
| 319 |
+
averaged_params[k] //= num_models
|
| 320 |
+
new_state["model"] = averaged_params
|
| 321 |
+
return new_state
|
| 322 |
+
|
| 323 |
+
|
| 324 |
+
def store_model_weights(model, checkpoint_path, checkpoint_key="model", strict=True):
|
| 325 |
+
"""
|
| 326 |
+
This method can be used to prepare weights files for new models. It receives as
|
| 327 |
+
input a model architecture and a checkpoint from the training script and produces
|
| 328 |
+
a file with the weights ready for release.
|
| 329 |
+
|
| 330 |
+
Examples:
|
| 331 |
+
from torchvision import models as M
|
| 332 |
+
|
| 333 |
+
# Classification
|
| 334 |
+
model = M.mobilenet_v3_large(weights=None)
|
| 335 |
+
print(store_model_weights(model, './class.pth'))
|
| 336 |
+
|
| 337 |
+
# Quantized Classification
|
| 338 |
+
model = M.quantization.mobilenet_v3_large(weights=None, quantize=False)
|
| 339 |
+
model.fuse_model(is_qat=True)
|
| 340 |
+
model.qconfig = torch.ao.quantization.get_default_qat_qconfig('qnnpack')
|
| 341 |
+
_ = torch.ao.quantization.prepare_qat(model, inplace=True)
|
| 342 |
+
print(store_model_weights(model, './qat.pth'))
|
| 343 |
+
|
| 344 |
+
# Object Detection
|
| 345 |
+
model = M.detection.fasterrcnn_mobilenet_v3_large_fpn(weights=None, weights_backbone=None)
|
| 346 |
+
print(store_model_weights(model, './obj.pth'))
|
| 347 |
+
|
| 348 |
+
# Segmentation
|
| 349 |
+
model = M.segmentation.deeplabv3_mobilenet_v3_large(weights=None, weights_backbone=None, aux_loss=True)
|
| 350 |
+
print(store_model_weights(model, './segm.pth', strict=False))
|
| 351 |
+
|
| 352 |
+
Args:
|
| 353 |
+
model (pytorch.nn.Module): The model on which the weights will be loaded for validation purposes.
|
| 354 |
+
checkpoint_path (str): The path of the checkpoint we will load.
|
| 355 |
+
checkpoint_key (str, optional): The key of the checkpoint where the model weights are stored.
|
| 356 |
+
Default: "model".
|
| 357 |
+
strict (bool): whether to strictly enforce that the keys
|
| 358 |
+
in :attr:`state_dict` match the keys returned by this module's
|
| 359 |
+
:meth:`~torch.nn.Module.state_dict` function. Default: ``True``
|
| 360 |
+
|
| 361 |
+
Returns:
|
| 362 |
+
output_path (str): The location where the weights are saved.
|
| 363 |
+
"""
|
| 364 |
+
# Store the new model next to the checkpoint_path
|
| 365 |
+
checkpoint_path = os.path.abspath(checkpoint_path)
|
| 366 |
+
output_dir = os.path.dirname(checkpoint_path)
|
| 367 |
+
|
| 368 |
+
# Deep copy to avoid side-effects on the model object.
|
| 369 |
+
model = copy.deepcopy(model)
|
| 370 |
+
checkpoint = torch.load(checkpoint_path, map_location="cpu")
|
| 371 |
+
|
| 372 |
+
# Load the weights to the model to validate that everything works
|
| 373 |
+
# and remove unnecessary weights (such as auxiliaries, etc)
|
| 374 |
+
if checkpoint_key == "model_ema":
|
| 375 |
+
del checkpoint[checkpoint_key]["n_averaged"]
|
| 376 |
+
torch.nn.modules.utils.consume_prefix_in_state_dict_if_present(checkpoint[checkpoint_key], "module.")
|
| 377 |
+
model.load_state_dict(checkpoint[checkpoint_key], strict=strict)
|
| 378 |
+
|
| 379 |
+
tmp_path = os.path.join(output_dir, str(model.__hash__()))
|
| 380 |
+
torch.save(model.state_dict(), tmp_path)
|
| 381 |
+
|
| 382 |
+
sha256_hash = hashlib.sha256()
|
| 383 |
+
with open(tmp_path, "rb") as f:
|
| 384 |
+
# Read and update hash string value in blocks of 4K
|
| 385 |
+
for byte_block in iter(lambda: f.read(4096), b""):
|
| 386 |
+
sha256_hash.update(byte_block)
|
| 387 |
+
hh = sha256_hash.hexdigest()
|
| 388 |
+
|
| 389 |
+
output_path = os.path.join(output_dir, "weights-" + str(hh[:8]) + ".pth")
|
| 390 |
+
os.replace(tmp_path, output_path)
|
| 391 |
+
|
| 392 |
+
return output_path
|
| 393 |
+
|
| 394 |
+
|
| 395 |
+
def reduce_across_processes(val):
|
| 396 |
+
if not is_dist_avail_and_initialized():
|
| 397 |
+
# nothing to sync, but we still convert to tensor for consistency with the distributed case.
|
| 398 |
+
return torch.tensor(val)
|
| 399 |
+
|
| 400 |
+
t = torch.tensor(val, device="cuda")
|
| 401 |
+
dist.barrier()
|
| 402 |
+
dist.all_reduce(t)
|
| 403 |
+
return t
|
| 404 |
+
|
| 405 |
+
|
| 406 |
+
def set_weight_decay(
|
| 407 |
+
model: torch.nn.Module,
|
| 408 |
+
weight_decay: float,
|
| 409 |
+
norm_weight_decay: Optional[float] = None,
|
| 410 |
+
norm_classes: Optional[List[type]] = None,
|
| 411 |
+
custom_keys_weight_decay: Optional[List[Tuple[str, float]]] = None,
|
| 412 |
+
):
|
| 413 |
+
if not norm_classes:
|
| 414 |
+
norm_classes = [
|
| 415 |
+
torch.nn.modules.batchnorm._BatchNorm,
|
| 416 |
+
torch.nn.LayerNorm,
|
| 417 |
+
torch.nn.GroupNorm,
|
| 418 |
+
torch.nn.modules.instancenorm._InstanceNorm,
|
| 419 |
+
torch.nn.LocalResponseNorm,
|
| 420 |
+
]
|
| 421 |
+
norm_classes = tuple(norm_classes)
|
| 422 |
+
|
| 423 |
+
params = {
|
| 424 |
+
"other": [],
|
| 425 |
+
"norm": [],
|
| 426 |
+
}
|
| 427 |
+
params_weight_decay = {
|
| 428 |
+
"other": weight_decay,
|
| 429 |
+
"norm": norm_weight_decay,
|
| 430 |
+
}
|
| 431 |
+
custom_keys = []
|
| 432 |
+
if custom_keys_weight_decay is not None:
|
| 433 |
+
for key, weight_decay in custom_keys_weight_decay:
|
| 434 |
+
params[key] = []
|
| 435 |
+
params_weight_decay[key] = weight_decay
|
| 436 |
+
custom_keys.append(key)
|
| 437 |
+
|
| 438 |
+
def _add_params(module, prefix=""):
|
| 439 |
+
for name, p in module.named_parameters(recurse=False):
|
| 440 |
+
if not p.requires_grad:
|
| 441 |
+
continue
|
| 442 |
+
is_custom_key = False
|
| 443 |
+
for key in custom_keys:
|
| 444 |
+
target_name = f"{prefix}.{name}" if prefix != "" and "." in key else name
|
| 445 |
+
if key == target_name:
|
| 446 |
+
params[key].append(p)
|
| 447 |
+
is_custom_key = True
|
| 448 |
+
break
|
| 449 |
+
if not is_custom_key:
|
| 450 |
+
if norm_weight_decay is not None and isinstance(module, norm_classes):
|
| 451 |
+
params["norm"].append(p)
|
| 452 |
+
else:
|
| 453 |
+
params["other"].append(p)
|
| 454 |
+
|
| 455 |
+
for child_name, child_module in module.named_children():
|
| 456 |
+
child_prefix = f"{prefix}.{child_name}" if prefix != "" else child_name
|
| 457 |
+
_add_params(child_module, prefix=child_prefix)
|
| 458 |
+
|
| 459 |
+
_add_params(model)
|
| 460 |
+
|
| 461 |
+
param_groups = []
|
| 462 |
+
for key in params:
|
| 463 |
+
if len(params[key]) > 0:
|
| 464 |
+
param_groups.append({"params": params[key], "weight_decay": params_weight_decay[key]})
|
| 465 |
+
return param_groups
|
hpo-examples/image-classification/vit_b_16/model_4.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:29f1bd991a3f27f7982e8700f31332cb94c4783b83e240fc71f1eca03d4eb468
|
| 3 |
+
size 1053172110
|
hpo-examples/question-answering/qa/README.md
ADDED
|
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
library_name: transformers
|
| 3 |
+
license: apache-2.0
|
| 4 |
+
base_model: google-bert/bert-base-uncased
|
| 5 |
+
tags:
|
| 6 |
+
- generated_from_trainer
|
| 7 |
+
datasets:
|
| 8 |
+
- squad
|
| 9 |
+
model-index:
|
| 10 |
+
- name: baseline
|
| 11 |
+
results: []
|
| 12 |
+
---
|
| 13 |
+
|
| 14 |
+
<!-- This model card has been generated automatically according to the information the Trainer had access to. You
|
| 15 |
+
should probably proofread and complete it, then remove this comment. -->
|
| 16 |
+
|
| 17 |
+
# baseline
|
| 18 |
+
|
| 19 |
+
This model is a fine-tuned version of [google-bert/bert-base-uncased](https://huggingface.co/google-bert/bert-base-uncased) on the squad dataset.
|
| 20 |
+
|
| 21 |
+
## Model description
|
| 22 |
+
|
| 23 |
+
More information needed
|
| 24 |
+
|
| 25 |
+
## Intended uses & limitations
|
| 26 |
+
|
| 27 |
+
More information needed
|
| 28 |
+
|
| 29 |
+
## Training and evaluation data
|
| 30 |
+
|
| 31 |
+
More information needed
|
| 32 |
+
|
| 33 |
+
## Training procedure
|
| 34 |
+
|
| 35 |
+
### Training hyperparameters
|
| 36 |
+
|
| 37 |
+
The following hyperparameters were used during training:
|
| 38 |
+
- learning_rate: 3e-05
|
| 39 |
+
- train_batch_size: 12
|
| 40 |
+
- eval_batch_size: 8
|
| 41 |
+
- seed: 42
|
| 42 |
+
- optimizer: Use adamw_torch with betas=(0.9,0.999) and epsilon=1e-08 and optimizer_args=No additional optimizer arguments
|
| 43 |
+
- lr_scheduler_type: linear
|
| 44 |
+
- num_epochs: 2.0
|
| 45 |
+
|
| 46 |
+
### Training results
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
### Framework versions
|
| 51 |
+
|
| 52 |
+
- Transformers 4.49.0
|
| 53 |
+
- Pytorch 2.6.0+cu118
|
| 54 |
+
- Datasets 3.3.1
|
| 55 |
+
- Tokenizers 0.21.0
|
hpo-examples/question-answering/qa/all_results.json
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"epoch": 2.0,
|
| 3 |
+
"eval_exact_match": 81.49479659413434,
|
| 4 |
+
"eval_f1": 88.62945564424126,
|
| 5 |
+
"eval_runtime": 61.0301,
|
| 6 |
+
"eval_samples": 10784,
|
| 7 |
+
"eval_samples_per_second": 176.7,
|
| 8 |
+
"eval_steps_per_second": 22.087,
|
| 9 |
+
"total_flos": 3.541929151120589e+16,
|
| 10 |
+
"train_loss": 1.148573803161563,
|
| 11 |
+
"train_runtime": 3245.3985,
|
| 12 |
+
"train_samples": 88524,
|
| 13 |
+
"train_samples_per_second": 54.554,
|
| 14 |
+
"train_steps_per_second": 4.546
|
| 15 |
+
}
|
hpo-examples/question-answering/qa/config.json
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_name_or_path": "google-bert/bert-base-uncased",
|
| 3 |
+
"architectures": [
|
| 4 |
+
"BertForQuestionAnswering"
|
| 5 |
+
],
|
| 6 |
+
"attention_probs_dropout_prob": 0.1,
|
| 7 |
+
"classifier_dropout": null,
|
| 8 |
+
"gradient_checkpointing": false,
|
| 9 |
+
"hidden_act": "gelu",
|
| 10 |
+
"hidden_dropout_prob": 0.1,
|
| 11 |
+
"hidden_size": 768,
|
| 12 |
+
"initializer_range": 0.02,
|
| 13 |
+
"intermediate_size": 3072,
|
| 14 |
+
"layer_norm_eps": 1e-12,
|
| 15 |
+
"max_position_embeddings": 512,
|
| 16 |
+
"model_type": "bert",
|
| 17 |
+
"num_attention_heads": 12,
|
| 18 |
+
"num_hidden_layers": 12,
|
| 19 |
+
"pad_token_id": 0,
|
| 20 |
+
"position_embedding_type": "absolute",
|
| 21 |
+
"torch_dtype": "float32",
|
| 22 |
+
"transformers_version": "4.49.0",
|
| 23 |
+
"type_vocab_size": 2,
|
| 24 |
+
"use_cache": true,
|
| 25 |
+
"vocab_size": 30522
|
| 26 |
+
}
|
hpo-examples/question-answering/qa/eval_nbest_predictions.json
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:8b8d44953cbe0ce20d1d1b62b72e7adba18bf1dc81d055492e22bfa21ff46657
|
| 3 |
+
size 49596120
|
hpo-examples/question-answering/qa/eval_predictions.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
hpo-examples/question-answering/qa/eval_results.json
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"epoch": 2.0,
|
| 3 |
+
"eval_exact_match": 81.49479659413434,
|
| 4 |
+
"eval_f1": 88.62945564424126,
|
| 5 |
+
"eval_runtime": 61.0301,
|
| 6 |
+
"eval_samples": 10784,
|
| 7 |
+
"eval_samples_per_second": 176.7,
|
| 8 |
+
"eval_steps_per_second": 22.087
|
| 9 |
+
}
|
hpo-examples/question-answering/qa/model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:38003bd65e4bfa70dd16886f29af7ab00d1aa0ae4de191b0a7de4d7883d17dde
|
| 3 |
+
size 442683784
|
hpo-examples/question-answering/qa/runs/May15_03-24-14_cs-Precision-7960-Tower/events.out.tfevents.1747293859.cs-Precision-7960-Tower.147971.0
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:36bfca6273a2422943de7b634cf75efd69b8e92079abe84df9e9c9e026d497f6
|
| 3 |
+
size 11535
|
hpo-examples/question-answering/qa/runs/May15_03-24-14_cs-Precision-7960-Tower/events.out.tfevents.1747297197.cs-Precision-7960-Tower.147971.1
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:259c79a03ba9c522b1fd728e92dae5cfc31c6cd73b2377d124749c83a0163910
|
| 3 |
+
size 412
|
hpo-examples/question-answering/qa/special_tokens_map.json
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cls_token": "[CLS]",
|
| 3 |
+
"mask_token": "[MASK]",
|
| 4 |
+
"pad_token": "[PAD]",
|
| 5 |
+
"sep_token": "[SEP]",
|
| 6 |
+
"unk_token": "[UNK]"
|
| 7 |
+
}
|
hpo-examples/question-answering/qa/tokenizer.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
hpo-examples/question-answering/qa/tokenizer_config.json
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"added_tokens_decoder": {
|
| 3 |
+
"0": {
|
| 4 |
+
"content": "[PAD]",
|
| 5 |
+
"lstrip": false,
|
| 6 |
+
"normalized": false,
|
| 7 |
+
"rstrip": false,
|
| 8 |
+
"single_word": false,
|
| 9 |
+
"special": true
|
| 10 |
+
},
|
| 11 |
+
"100": {
|
| 12 |
+
"content": "[UNK]",
|
| 13 |
+
"lstrip": false,
|
| 14 |
+
"normalized": false,
|
| 15 |
+
"rstrip": false,
|
| 16 |
+
"single_word": false,
|
| 17 |
+
"special": true
|
| 18 |
+
},
|
| 19 |
+
"101": {
|
| 20 |
+
"content": "[CLS]",
|
| 21 |
+
"lstrip": false,
|
| 22 |
+
"normalized": false,
|
| 23 |
+
"rstrip": false,
|
| 24 |
+
"single_word": false,
|
| 25 |
+
"special": true
|
| 26 |
+
},
|
| 27 |
+
"102": {
|
| 28 |
+
"content": "[SEP]",
|
| 29 |
+
"lstrip": false,
|
| 30 |
+
"normalized": false,
|
| 31 |
+
"rstrip": false,
|
| 32 |
+
"single_word": false,
|
| 33 |
+
"special": true
|
| 34 |
+
},
|
| 35 |
+
"103": {
|
| 36 |
+
"content": "[MASK]",
|
| 37 |
+
"lstrip": false,
|
| 38 |
+
"normalized": false,
|
| 39 |
+
"rstrip": false,
|
| 40 |
+
"single_word": false,
|
| 41 |
+
"special": true
|
| 42 |
+
}
|
| 43 |
+
},
|
| 44 |
+
"clean_up_tokenization_spaces": false,
|
| 45 |
+
"cls_token": "[CLS]",
|
| 46 |
+
"do_lower_case": true,
|
| 47 |
+
"extra_special_tokens": {},
|
| 48 |
+
"mask_token": "[MASK]",
|
| 49 |
+
"model_max_length": 512,
|
| 50 |
+
"pad_token": "[PAD]",
|
| 51 |
+
"sep_token": "[SEP]",
|
| 52 |
+
"strip_accents": null,
|
| 53 |
+
"tokenize_chinese_chars": true,
|
| 54 |
+
"tokenizer_class": "BertTokenizer",
|
| 55 |
+
"unk_token": "[UNK]"
|
| 56 |
+
}
|
hpo-examples/question-answering/qa/train_results.json
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"epoch": 2.0,
|
| 3 |
+
"total_flos": 3.541929151120589e+16,
|
| 4 |
+
"train_loss": 1.148573803161563,
|
| 5 |
+
"train_runtime": 3245.3985,
|
| 6 |
+
"train_samples": 88524,
|
| 7 |
+
"train_samples_per_second": 54.554,
|
| 8 |
+
"train_steps_per_second": 4.546
|
| 9 |
+
}
|
hpo-examples/question-answering/qa/trainer_state.json
ADDED
|
@@ -0,0 +1,245 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"best_metric": null,
|
| 3 |
+
"best_model_checkpoint": null,
|
| 4 |
+
"epoch": 2.0,
|
| 5 |
+
"eval_steps": 500,
|
| 6 |
+
"global_step": 14754,
|
| 7 |
+
"is_hyper_param_search": false,
|
| 8 |
+
"is_local_process_zero": true,
|
| 9 |
+
"is_world_process_zero": true,
|
| 10 |
+
"log_history": [
|
| 11 |
+
{
|
| 12 |
+
"epoch": 0.06777822963264199,
|
| 13 |
+
"grad_norm": 31.397275924682617,
|
| 14 |
+
"learning_rate": 2.8983326555510372e-05,
|
| 15 |
+
"loss": 2.7299,
|
| 16 |
+
"step": 500
|
| 17 |
+
},
|
| 18 |
+
{
|
| 19 |
+
"epoch": 0.13555645926528398,
|
| 20 |
+
"grad_norm": 25.8492431640625,
|
| 21 |
+
"learning_rate": 2.796665311102074e-05,
|
| 22 |
+
"loss": 1.752,
|
| 23 |
+
"step": 1000
|
| 24 |
+
},
|
| 25 |
+
{
|
| 26 |
+
"epoch": 0.203334688897926,
|
| 27 |
+
"grad_norm": 29.627431869506836,
|
| 28 |
+
"learning_rate": 2.694997966653111e-05,
|
| 29 |
+
"loss": 1.5588,
|
| 30 |
+
"step": 1500
|
| 31 |
+
},
|
| 32 |
+
{
|
| 33 |
+
"epoch": 0.27111291853056796,
|
| 34 |
+
"grad_norm": 21.147193908691406,
|
| 35 |
+
"learning_rate": 2.593330622204148e-05,
|
| 36 |
+
"loss": 1.5014,
|
| 37 |
+
"step": 2000
|
| 38 |
+
},
|
| 39 |
+
{
|
| 40 |
+
"epoch": 0.33889114816321,
|
| 41 |
+
"grad_norm": 17.81966781616211,
|
| 42 |
+
"learning_rate": 2.491663277755185e-05,
|
| 43 |
+
"loss": 1.4768,
|
| 44 |
+
"step": 2500
|
| 45 |
+
},
|
| 46 |
+
{
|
| 47 |
+
"epoch": 0.406669377795852,
|
| 48 |
+
"grad_norm": 20.26822853088379,
|
| 49 |
+
"learning_rate": 2.389995933306222e-05,
|
| 50 |
+
"loss": 1.4064,
|
| 51 |
+
"step": 3000
|
| 52 |
+
},
|
| 53 |
+
{
|
| 54 |
+
"epoch": 0.47444760742849396,
|
| 55 |
+
"grad_norm": 16.216028213500977,
|
| 56 |
+
"learning_rate": 2.288328588857259e-05,
|
| 57 |
+
"loss": 1.3502,
|
| 58 |
+
"step": 3500
|
| 59 |
+
},
|
| 60 |
+
{
|
| 61 |
+
"epoch": 0.5422258370611359,
|
| 62 |
+
"grad_norm": 17.930505752563477,
|
| 63 |
+
"learning_rate": 2.1866612444082963e-05,
|
| 64 |
+
"loss": 1.3101,
|
| 65 |
+
"step": 4000
|
| 66 |
+
},
|
| 67 |
+
{
|
| 68 |
+
"epoch": 0.6100040666937779,
|
| 69 |
+
"grad_norm": 26.499574661254883,
|
| 70 |
+
"learning_rate": 2.084993899959333e-05,
|
| 71 |
+
"loss": 1.2922,
|
| 72 |
+
"step": 4500
|
| 73 |
+
},
|
| 74 |
+
{
|
| 75 |
+
"epoch": 0.67778229632642,
|
| 76 |
+
"grad_norm": 26.83368492126465,
|
| 77 |
+
"learning_rate": 1.9833265555103702e-05,
|
| 78 |
+
"loss": 1.3053,
|
| 79 |
+
"step": 5000
|
| 80 |
+
},
|
| 81 |
+
{
|
| 82 |
+
"epoch": 0.745560525959062,
|
| 83 |
+
"grad_norm": 22.85872459411621,
|
| 84 |
+
"learning_rate": 1.8816592110614073e-05,
|
| 85 |
+
"loss": 1.2555,
|
| 86 |
+
"step": 5500
|
| 87 |
+
},
|
| 88 |
+
{
|
| 89 |
+
"epoch": 0.813338755591704,
|
| 90 |
+
"grad_norm": 23.48080825805664,
|
| 91 |
+
"learning_rate": 1.779991866612444e-05,
|
| 92 |
+
"loss": 1.2068,
|
| 93 |
+
"step": 6000
|
| 94 |
+
},
|
| 95 |
+
{
|
| 96 |
+
"epoch": 0.8811169852243459,
|
| 97 |
+
"grad_norm": 20.919252395629883,
|
| 98 |
+
"learning_rate": 1.6783245221634812e-05,
|
| 99 |
+
"loss": 1.1991,
|
| 100 |
+
"step": 6500
|
| 101 |
+
},
|
| 102 |
+
{
|
| 103 |
+
"epoch": 0.9488952148569879,
|
| 104 |
+
"grad_norm": 23.9005126953125,
|
| 105 |
+
"learning_rate": 1.576657177714518e-05,
|
| 106 |
+
"loss": 1.2156,
|
| 107 |
+
"step": 7000
|
| 108 |
+
},
|
| 109 |
+
{
|
| 110 |
+
"epoch": 1.01667344448963,
|
| 111 |
+
"grad_norm": 22.660743713378906,
|
| 112 |
+
"learning_rate": 1.4749898332655551e-05,
|
| 113 |
+
"loss": 1.0827,
|
| 114 |
+
"step": 7500
|
| 115 |
+
},
|
| 116 |
+
{
|
| 117 |
+
"epoch": 1.0844516741222718,
|
| 118 |
+
"grad_norm": 25.28419303894043,
|
| 119 |
+
"learning_rate": 1.373322488816592e-05,
|
| 120 |
+
"loss": 0.8481,
|
| 121 |
+
"step": 8000
|
| 122 |
+
},
|
| 123 |
+
{
|
| 124 |
+
"epoch": 1.152229903754914,
|
| 125 |
+
"grad_norm": 14.510698318481445,
|
| 126 |
+
"learning_rate": 1.271655144367629e-05,
|
| 127 |
+
"loss": 0.872,
|
| 128 |
+
"step": 8500
|
| 129 |
+
},
|
| 130 |
+
{
|
| 131 |
+
"epoch": 1.2200081333875559,
|
| 132 |
+
"grad_norm": 29.12289810180664,
|
| 133 |
+
"learning_rate": 1.1699877999186661e-05,
|
| 134 |
+
"loss": 0.8375,
|
| 135 |
+
"step": 9000
|
| 136 |
+
},
|
| 137 |
+
{
|
| 138 |
+
"epoch": 1.287786363020198,
|
| 139 |
+
"grad_norm": 19.038454055786133,
|
| 140 |
+
"learning_rate": 1.0683204554697033e-05,
|
| 141 |
+
"loss": 0.8464,
|
| 142 |
+
"step": 9500
|
| 143 |
+
},
|
| 144 |
+
{
|
| 145 |
+
"epoch": 1.35556459265284,
|
| 146 |
+
"grad_norm": 21.09101676940918,
|
| 147 |
+
"learning_rate": 9.666531110207402e-06,
|
| 148 |
+
"loss": 0.8746,
|
| 149 |
+
"step": 10000
|
| 150 |
+
},
|
| 151 |
+
{
|
| 152 |
+
"epoch": 1.4233428222854818,
|
| 153 |
+
"grad_norm": 20.79250144958496,
|
| 154 |
+
"learning_rate": 8.649857665717772e-06,
|
| 155 |
+
"loss": 0.8776,
|
| 156 |
+
"step": 10500
|
| 157 |
+
},
|
| 158 |
+
{
|
| 159 |
+
"epoch": 1.491121051918124,
|
| 160 |
+
"grad_norm": 21.217571258544922,
|
| 161 |
+
"learning_rate": 7.633184221228141e-06,
|
| 162 |
+
"loss": 0.8523,
|
| 163 |
+
"step": 11000
|
| 164 |
+
},
|
| 165 |
+
{
|
| 166 |
+
"epoch": 1.5588992815507658,
|
| 167 |
+
"grad_norm": 15.557079315185547,
|
| 168 |
+
"learning_rate": 6.616510776738511e-06,
|
| 169 |
+
"loss": 0.8387,
|
| 170 |
+
"step": 11500
|
| 171 |
+
},
|
| 172 |
+
{
|
| 173 |
+
"epoch": 1.626677511183408,
|
| 174 |
+
"grad_norm": 14.53345012664795,
|
| 175 |
+
"learning_rate": 5.5998373322488825e-06,
|
| 176 |
+
"loss": 0.8377,
|
| 177 |
+
"step": 12000
|
| 178 |
+
},
|
| 179 |
+
{
|
| 180 |
+
"epoch": 1.6944557408160499,
|
| 181 |
+
"grad_norm": 26.921611785888672,
|
| 182 |
+
"learning_rate": 4.583163887759252e-06,
|
| 183 |
+
"loss": 0.8449,
|
| 184 |
+
"step": 12500
|
| 185 |
+
},
|
| 186 |
+
{
|
| 187 |
+
"epoch": 1.7622339704486918,
|
| 188 |
+
"grad_norm": 12.789366722106934,
|
| 189 |
+
"learning_rate": 3.566490443269622e-06,
|
| 190 |
+
"loss": 0.8547,
|
| 191 |
+
"step": 13000
|
| 192 |
+
},
|
| 193 |
+
{
|
| 194 |
+
"epoch": 1.830012200081334,
|
| 195 |
+
"grad_norm": 37.19759750366211,
|
| 196 |
+
"learning_rate": 2.549816998779992e-06,
|
| 197 |
+
"loss": 0.818,
|
| 198 |
+
"step": 13500
|
| 199 |
+
},
|
| 200 |
+
{
|
| 201 |
+
"epoch": 1.8977904297139758,
|
| 202 |
+
"grad_norm": 14.62682819366455,
|
| 203 |
+
"learning_rate": 1.533143554290362e-06,
|
| 204 |
+
"loss": 0.8128,
|
| 205 |
+
"step": 14000
|
| 206 |
+
},
|
| 207 |
+
{
|
| 208 |
+
"epoch": 1.965568659346618,
|
| 209 |
+
"grad_norm": 21.051790237426758,
|
| 210 |
+
"learning_rate": 5.164701098007319e-07,
|
| 211 |
+
"loss": 0.8115,
|
| 212 |
+
"step": 14500
|
| 213 |
+
},
|
| 214 |
+
{
|
| 215 |
+
"epoch": 2.0,
|
| 216 |
+
"step": 14754,
|
| 217 |
+
"total_flos": 3.541929151120589e+16,
|
| 218 |
+
"train_loss": 1.148573803161563,
|
| 219 |
+
"train_runtime": 3245.3985,
|
| 220 |
+
"train_samples_per_second": 54.554,
|
| 221 |
+
"train_steps_per_second": 4.546
|
| 222 |
+
}
|
| 223 |
+
],
|
| 224 |
+
"logging_steps": 500,
|
| 225 |
+
"max_steps": 14754,
|
| 226 |
+
"num_input_tokens_seen": 0,
|
| 227 |
+
"num_train_epochs": 2,
|
| 228 |
+
"save_steps": 500,
|
| 229 |
+
"stateful_callbacks": {
|
| 230 |
+
"TrainerControl": {
|
| 231 |
+
"args": {
|
| 232 |
+
"should_epoch_stop": false,
|
| 233 |
+
"should_evaluate": false,
|
| 234 |
+
"should_log": false,
|
| 235 |
+
"should_save": true,
|
| 236 |
+
"should_training_stop": true
|
| 237 |
+
},
|
| 238 |
+
"attributes": {}
|
| 239 |
+
}
|
| 240 |
+
},
|
| 241 |
+
"total_flos": 3.541929151120589e+16,
|
| 242 |
+
"train_batch_size": 12,
|
| 243 |
+
"trial_name": null,
|
| 244 |
+
"trial_params": null
|
| 245 |
+
}
|
hpo-examples/question-answering/qa/training_args.bin
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:fe8e61ba1ca1cb106ca9adca5e9262fa9a262238814728a69256855c78c32f51
|
| 3 |
+
size 5304
|
hpo-examples/question-answering/qa/vocab.txt
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
hpo-examples/question-answering/requirements.txt
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
accelerate >= 0.12.0
|
| 2 |
+
datasets >= 1.8.0
|
| 3 |
+
torch >= 1.3.0
|
| 4 |
+
evaluate
|