UniversalAlgorithmic commited on
Commit
09a2af4
·
verified ·
1 Parent(s): d8e11b0

Upload 178 files

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +1 -0
  2. hpo-examples/audio-classification/ac/README.md +88 -0
  3. hpo-examples/audio-classification/ac/all_results.json +13 -0
  4. hpo-examples/audio-classification/ac/config.json +147 -0
  5. hpo-examples/audio-classification/ac/eval_results.json +8 -0
  6. hpo-examples/audio-classification/ac/model.safetensors +3 -0
  7. hpo-examples/audio-classification/ac/preprocessor_config.json +9 -0
  8. hpo-examples/audio-classification/ac/runs/May15_03-06-03_cs-Precision-7960-Tower/events.out.tfevents.1747292768.cs-Precision-7960-Tower.146737.0 +3 -0
  9. hpo-examples/audio-classification/ac/runs/May15_03-06-03_cs-Precision-7960-Tower/events.out.tfevents.1747293535.cs-Precision-7960-Tower.146737.1 +3 -0
  10. hpo-examples/audio-classification/ac/train_results.json +8 -0
  11. hpo-examples/audio-classification/ac/trainer_state.json +1598 -0
  12. hpo-examples/audio-classification/ac/training_args.bin +3 -0
  13. hpo-examples/audio-classification/requirements.txt +5 -0
  14. hpo-examples/audio-classification/run.sh +30 -0
  15. hpo-examples/audio-classification/run_audio_classification.py +462 -0
  16. hpo-examples/audio-classification/trplib.py +1181 -0
  17. hpo-examples/image-classification/__pycache__/presets.cpython-310.pyc +0 -0
  18. hpo-examples/image-classification/__pycache__/sampler.cpython-310.pyc +0 -0
  19. hpo-examples/image-classification/__pycache__/transforms.cpython-310.pyc +0 -0
  20. hpo-examples/image-classification/__pycache__/trplib.cpython-310.pyc +0 -0
  21. hpo-examples/image-classification/__pycache__/utils.cpython-310.pyc +0 -0
  22. hpo-examples/image-classification/efficientnet_v2_m/model_7.pth +3 -0
  23. hpo-examples/image-classification/mobilenetv2/model_32.pth +3 -0
  24. hpo-examples/image-classification/presets.py +71 -0
  25. hpo-examples/image-classification/resnet50/model_35.pth +3 -0
  26. hpo-examples/image-classification/run.sh +49 -0
  27. hpo-examples/image-classification/sampler.py +62 -0
  28. hpo-examples/image-classification/train.py +524 -0
  29. hpo-examples/image-classification/train_quantization.py +265 -0
  30. hpo-examples/image-classification/transforms.py +183 -0
  31. hpo-examples/image-classification/trplib.py +1181 -0
  32. hpo-examples/image-classification/utils.py +465 -0
  33. hpo-examples/image-classification/vit_b_16/model_4.pth +3 -0
  34. hpo-examples/question-answering/qa/README.md +55 -0
  35. hpo-examples/question-answering/qa/all_results.json +15 -0
  36. hpo-examples/question-answering/qa/config.json +26 -0
  37. hpo-examples/question-answering/qa/eval_nbest_predictions.json +3 -0
  38. hpo-examples/question-answering/qa/eval_predictions.json +0 -0
  39. hpo-examples/question-answering/qa/eval_results.json +9 -0
  40. hpo-examples/question-answering/qa/model.safetensors +3 -0
  41. hpo-examples/question-answering/qa/runs/May15_03-24-14_cs-Precision-7960-Tower/events.out.tfevents.1747293859.cs-Precision-7960-Tower.147971.0 +3 -0
  42. hpo-examples/question-answering/qa/runs/May15_03-24-14_cs-Precision-7960-Tower/events.out.tfevents.1747297197.cs-Precision-7960-Tower.147971.1 +3 -0
  43. hpo-examples/question-answering/qa/special_tokens_map.json +7 -0
  44. hpo-examples/question-answering/qa/tokenizer.json +0 -0
  45. hpo-examples/question-answering/qa/tokenizer_config.json +56 -0
  46. hpo-examples/question-answering/qa/train_results.json +9 -0
  47. hpo-examples/question-answering/qa/trainer_state.json +245 -0
  48. hpo-examples/question-answering/qa/training_args.bin +3 -0
  49. hpo-examples/question-answering/qa/vocab.txt +0 -0
  50. hpo-examples/question-answering/requirements.txt +4 -0
.gitattributes CHANGED
@@ -37,3 +37,4 @@ qa/eval_nbest_predictions.json filter=lfs diff=lfs merge=lfs -text
37
  qa/sequential-policy-gradient.pdf filter=lfs diff=lfs merge=lfs -text
38
  sequential-policy-gradient.png filter=lfs diff=lfs merge=lfs -text
39
  examples/question-answering/qa/eval_nbest_predictions.json filter=lfs diff=lfs merge=lfs -text
 
 
37
  qa/sequential-policy-gradient.pdf filter=lfs diff=lfs merge=lfs -text
38
  sequential-policy-gradient.png filter=lfs diff=lfs merge=lfs -text
39
  examples/question-answering/qa/eval_nbest_predictions.json filter=lfs diff=lfs merge=lfs -text
40
+ hpo-examples/question-answering/qa/eval_nbest_predictions.json filter=lfs diff=lfs merge=lfs -text
hpo-examples/audio-classification/ac/README.md ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ library_name: transformers
3
+ license: apache-2.0
4
+ base_model: facebook/wav2vec2-base
5
+ tags:
6
+ - audio-classification
7
+ - generated_from_trainer
8
+ datasets:
9
+ - superb
10
+ metrics:
11
+ - accuracy
12
+ model-index:
13
+ - name: wav2vec2-base-ft-keyword-spotting
14
+ results:
15
+ - task:
16
+ name: Audio Classification
17
+ type: audio-classification
18
+ dataset:
19
+ name: superb
20
+ type: superb
21
+ config: ks
22
+ split: validation
23
+ args: ks
24
+ metrics:
25
+ - name: Accuracy
26
+ type: accuracy
27
+ value: 0.9826419535157399
28
+ ---
29
+
30
+ <!-- This model card has been generated automatically according to the information the Trainer had access to. You
31
+ should probably proofread and complete it, then remove this comment. -->
32
+
33
+ # wav2vec2-base-ft-keyword-spotting
34
+
35
+ This model is a fine-tuned version of [facebook/wav2vec2-base](https://huggingface.co/facebook/wav2vec2-base) on the superb dataset.
36
+ It achieves the following results on the evaluation set:
37
+ - Loss: 0.0954
38
+ - Accuracy: 0.9826
39
+
40
+ ## Model description
41
+
42
+ More information needed
43
+
44
+ ## Intended uses & limitations
45
+
46
+ More information needed
47
+
48
+ ## Training and evaluation data
49
+
50
+ More information needed
51
+
52
+ ## Training procedure
53
+
54
+ ### Training hyperparameters
55
+
56
+ The following hyperparameters were used during training:
57
+ - learning_rate: 3e-05
58
+ - train_batch_size: 48
59
+ - eval_batch_size: 32
60
+ - seed: 0
61
+ - gradient_accumulation_steps: 4
62
+ - total_train_batch_size: 192
63
+ - optimizer: Use adamw_torch with betas=(0.9,0.999) and epsilon=1e-08 and optimizer_args=No additional optimizer arguments
64
+ - lr_scheduler_type: linear
65
+ - lr_scheduler_warmup_ratio: 0.1
66
+ - num_epochs: 8.0
67
+ - mixed_precision_training: Native AMP
68
+
69
+ ### Training results
70
+
71
+ | Training Loss | Epoch | Step | Validation Loss | Accuracy |
72
+ |:-------------:|:------:|:----:|:---------------:|:--------:|
73
+ | 1.3624 | 1.0 | 267 | 1.1959 | 0.6546 |
74
+ | 0.3854 | 2.0 | 534 | 0.2675 | 0.9734 |
75
+ | 0.2473 | 3.0 | 801 | 0.1461 | 0.9768 |
76
+ | 0.1997 | 4.0 | 1068 | 0.1088 | 0.9804 |
77
+ | 0.1723 | 5.0 | 1335 | 0.0954 | 0.9826 |
78
+ | 0.1442 | 6.0 | 1602 | 0.0927 | 0.9813 |
79
+ | 0.1397 | 7.0 | 1869 | 0.0892 | 0.9812 |
80
+ | 0.1368 | 7.9728 | 2128 | 0.0896 | 0.9812 |
81
+
82
+
83
+ ### Framework versions
84
+
85
+ - Transformers 4.49.0
86
+ - Pytorch 2.6.0+cu118
87
+ - Datasets 3.3.1
88
+ - Tokenizers 0.21.0
hpo-examples/audio-classification/ac/all_results.json ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "epoch": 7.972769953051643,
3
+ "eval_accuracy": 0.9826419535157399,
4
+ "eval_loss": 0.09542840719223022,
5
+ "eval_runtime": 5.5538,
6
+ "eval_samples_per_second": 1224.023,
7
+ "eval_steps_per_second": 38.352,
8
+ "total_flos": 3.767900833756416e+18,
9
+ "train_loss": 0.5178930132572812,
10
+ "train_runtime": 756.2923,
11
+ "train_samples_per_second": 540.468,
12
+ "train_steps_per_second": 2.814
13
+ }
hpo-examples/audio-classification/ac/config.json ADDED
@@ -0,0 +1,147 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "facebook/wav2vec2-base",
3
+ "activation_dropout": 0.0,
4
+ "adapter_attn_dim": null,
5
+ "adapter_kernel_size": 3,
6
+ "adapter_stride": 2,
7
+ "add_adapter": false,
8
+ "apply_spec_augment": true,
9
+ "architectures": [
10
+ "Wav2Vec2ForSequenceClassification"
11
+ ],
12
+ "attention_dropout": 0.1,
13
+ "bos_token_id": 1,
14
+ "classifier_proj_size": 256,
15
+ "codevector_dim": 256,
16
+ "contrastive_logits_temperature": 0.1,
17
+ "conv_bias": false,
18
+ "conv_dim": [
19
+ 512,
20
+ 512,
21
+ 512,
22
+ 512,
23
+ 512,
24
+ 512,
25
+ 512
26
+ ],
27
+ "conv_kernel": [
28
+ 10,
29
+ 3,
30
+ 3,
31
+ 3,
32
+ 3,
33
+ 2,
34
+ 2
35
+ ],
36
+ "conv_stride": [
37
+ 5,
38
+ 2,
39
+ 2,
40
+ 2,
41
+ 2,
42
+ 2,
43
+ 2
44
+ ],
45
+ "ctc_loss_reduction": "sum",
46
+ "ctc_zero_infinity": false,
47
+ "diversity_loss_weight": 0.1,
48
+ "do_stable_layer_norm": false,
49
+ "eos_token_id": 2,
50
+ "feat_extract_activation": "gelu",
51
+ "feat_extract_norm": "group",
52
+ "feat_proj_dropout": 0.1,
53
+ "feat_quantizer_dropout": 0.0,
54
+ "final_dropout": 0.0,
55
+ "finetuning_task": "audio-classification",
56
+ "freeze_feat_extract_train": true,
57
+ "hidden_act": "gelu",
58
+ "hidden_dropout": 0.1,
59
+ "hidden_size": 768,
60
+ "id2label": {
61
+ "0": "yes",
62
+ "1": "no",
63
+ "10": "_silence_",
64
+ "11": "_unknown_",
65
+ "2": "up",
66
+ "3": "down",
67
+ "4": "left",
68
+ "5": "right",
69
+ "6": "on",
70
+ "7": "off",
71
+ "8": "stop",
72
+ "9": "go"
73
+ },
74
+ "initializer_range": 0.02,
75
+ "intermediate_size": 3072,
76
+ "label2id": {
77
+ "_silence_": "10",
78
+ "_unknown_": "11",
79
+ "down": "3",
80
+ "go": "9",
81
+ "left": "4",
82
+ "no": "1",
83
+ "off": "7",
84
+ "on": "6",
85
+ "right": "5",
86
+ "stop": "8",
87
+ "up": "2",
88
+ "yes": "0"
89
+ },
90
+ "layer_norm_eps": 1e-05,
91
+ "layerdrop": 0.0,
92
+ "mask_channel_length": 10,
93
+ "mask_channel_min_space": 1,
94
+ "mask_channel_other": 0.0,
95
+ "mask_channel_prob": 0.0,
96
+ "mask_channel_selection": "static",
97
+ "mask_feature_length": 10,
98
+ "mask_feature_min_masks": 0,
99
+ "mask_feature_prob": 0.0,
100
+ "mask_time_length": 10,
101
+ "mask_time_min_masks": 2,
102
+ "mask_time_min_space": 1,
103
+ "mask_time_other": 0.0,
104
+ "mask_time_prob": 0.05,
105
+ "mask_time_selection": "static",
106
+ "model_type": "wav2vec2",
107
+ "no_mask_channel_overlap": false,
108
+ "no_mask_time_overlap": false,
109
+ "num_adapter_layers": 3,
110
+ "num_attention_heads": 12,
111
+ "num_codevector_groups": 2,
112
+ "num_codevectors_per_group": 320,
113
+ "num_conv_pos_embedding_groups": 16,
114
+ "num_conv_pos_embeddings": 128,
115
+ "num_feat_extract_layers": 7,
116
+ "num_hidden_layers": 12,
117
+ "num_negatives": 100,
118
+ "output_hidden_size": 768,
119
+ "pad_token_id": 0,
120
+ "proj_codevector_dim": 256,
121
+ "tdnn_dilation": [
122
+ 1,
123
+ 2,
124
+ 3,
125
+ 1,
126
+ 1
127
+ ],
128
+ "tdnn_dim": [
129
+ 512,
130
+ 512,
131
+ 512,
132
+ 512,
133
+ 1500
134
+ ],
135
+ "tdnn_kernel": [
136
+ 5,
137
+ 3,
138
+ 3,
139
+ 1,
140
+ 1
141
+ ],
142
+ "torch_dtype": "float32",
143
+ "transformers_version": "4.49.0",
144
+ "use_weighted_layer_sum": false,
145
+ "vocab_size": 32,
146
+ "xvector_output_dim": 512
147
+ }
hpo-examples/audio-classification/ac/eval_results.json ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "epoch": 7.972769953051643,
3
+ "eval_accuracy": 0.9826419535157399,
4
+ "eval_loss": 0.09542840719223022,
5
+ "eval_runtime": 5.5538,
6
+ "eval_samples_per_second": 1224.023,
7
+ "eval_steps_per_second": 38.352
8
+ }
hpo-examples/audio-classification/ac/model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d6e4f1c85d883f3e41ebfab4cd7752ab2e6d6b968b847795be22e1e0662657a3
3
+ size 385400352
hpo-examples/audio-classification/ac/preprocessor_config.json ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "do_normalize": true,
3
+ "feature_extractor_type": "Wav2Vec2FeatureExtractor",
4
+ "feature_size": 1,
5
+ "padding_side": "right",
6
+ "padding_value": 0.0,
7
+ "return_attention_mask": false,
8
+ "sampling_rate": 16000
9
+ }
hpo-examples/audio-classification/ac/runs/May15_03-06-03_cs-Precision-7960-Tower/events.out.tfevents.1747292768.cs-Precision-7960-Tower.146737.0 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:75015326d15cb1786d8788b228e95bd7e952bad9e8e8be4ed02764cc17c8464c
3
+ size 54953
hpo-examples/audio-classification/ac/runs/May15_03-06-03_cs-Precision-7960-Tower/events.out.tfevents.1747293535.cs-Precision-7960-Tower.146737.1 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:215a99aa1f96db505ff53265c988d18ddcf70f16a1f13e713316a0354b768356
3
+ size 411
hpo-examples/audio-classification/ac/train_results.json ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "epoch": 7.972769953051643,
3
+ "total_flos": 3.767900833756416e+18,
4
+ "train_loss": 0.5178930132572812,
5
+ "train_runtime": 756.2923,
6
+ "train_samples_per_second": 540.468,
7
+ "train_steps_per_second": 2.814
8
+ }
hpo-examples/audio-classification/ac/trainer_state.json ADDED
@@ -0,0 +1,1598 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "best_metric": 0.9826419535157399,
3
+ "best_model_checkpoint": "wav2vec2-base-ft-keyword-spotting/checkpoint-1335",
4
+ "epoch": 7.972769953051643,
5
+ "eval_steps": 500,
6
+ "global_step": 2128,
7
+ "is_hyper_param_search": false,
8
+ "is_local_process_zero": true,
9
+ "is_world_process_zero": true,
10
+ "log_history": [
11
+ {
12
+ "epoch": 0.03755868544600939,
13
+ "grad_norm": 2.0377416610717773,
14
+ "learning_rate": 1.4084507042253521e-06,
15
+ "loss": 3.8687,
16
+ "step": 10
17
+ },
18
+ {
19
+ "epoch": 0.07511737089201878,
20
+ "grad_norm": 3.055781602859497,
21
+ "learning_rate": 2.8169014084507042e-06,
22
+ "loss": 4.1156,
23
+ "step": 20
24
+ },
25
+ {
26
+ "epoch": 0.11267605633802817,
27
+ "grad_norm": 3.383268356323242,
28
+ "learning_rate": 4.225352112676057e-06,
29
+ "loss": 4.0885,
30
+ "step": 30
31
+ },
32
+ {
33
+ "epoch": 0.15023474178403756,
34
+ "grad_norm": 3.8566606044769287,
35
+ "learning_rate": 5.6338028169014084e-06,
36
+ "loss": 3.9316,
37
+ "step": 40
38
+ },
39
+ {
40
+ "epoch": 0.18779342723004694,
41
+ "grad_norm": 5.065456867218018,
42
+ "learning_rate": 7.042253521126761e-06,
43
+ "loss": 3.6474,
44
+ "step": 50
45
+ },
46
+ {
47
+ "epoch": 0.22535211267605634,
48
+ "grad_norm": 5.89341926574707,
49
+ "learning_rate": 8.450704225352114e-06,
50
+ "loss": 3.2124,
51
+ "step": 60
52
+ },
53
+ {
54
+ "epoch": 0.26291079812206575,
55
+ "grad_norm": 5.9929399490356445,
56
+ "learning_rate": 9.859154929577466e-06,
57
+ "loss": 2.756,
58
+ "step": 70
59
+ },
60
+ {
61
+ "epoch": 0.3004694835680751,
62
+ "grad_norm": 5.689433574676514,
63
+ "learning_rate": 1.1267605633802817e-05,
64
+ "loss": 2.4596,
65
+ "step": 80
66
+ },
67
+ {
68
+ "epoch": 0.3380281690140845,
69
+ "grad_norm": 4.89589262008667,
70
+ "learning_rate": 1.267605633802817e-05,
71
+ "loss": 2.2638,
72
+ "step": 90
73
+ },
74
+ {
75
+ "epoch": 0.3755868544600939,
76
+ "grad_norm": 4.8666839599609375,
77
+ "learning_rate": 1.4084507042253522e-05,
78
+ "loss": 2.1166,
79
+ "step": 100
80
+ },
81
+ {
82
+ "epoch": 0.4131455399061033,
83
+ "grad_norm": 4.466708660125732,
84
+ "learning_rate": 1.5492957746478876e-05,
85
+ "loss": 2.0048,
86
+ "step": 110
87
+ },
88
+ {
89
+ "epoch": 0.4507042253521127,
90
+ "grad_norm": 3.676050901412964,
91
+ "learning_rate": 1.6901408450704228e-05,
92
+ "loss": 1.9138,
93
+ "step": 120
94
+ },
95
+ {
96
+ "epoch": 0.48826291079812206,
97
+ "grad_norm": 2.183825731277466,
98
+ "learning_rate": 1.830985915492958e-05,
99
+ "loss": 1.863,
100
+ "step": 130
101
+ },
102
+ {
103
+ "epoch": 0.5258215962441315,
104
+ "grad_norm": 2.075413465499878,
105
+ "learning_rate": 1.9718309859154933e-05,
106
+ "loss": 1.7616,
107
+ "step": 140
108
+ },
109
+ {
110
+ "epoch": 0.5633802816901409,
111
+ "grad_norm": 0.8534318208694458,
112
+ "learning_rate": 2.112676056338028e-05,
113
+ "loss": 1.7185,
114
+ "step": 150
115
+ },
116
+ {
117
+ "epoch": 0.6009389671361502,
118
+ "grad_norm": 0.9039830565452576,
119
+ "learning_rate": 2.2535211267605634e-05,
120
+ "loss": 1.8054,
121
+ "step": 160
122
+ },
123
+ {
124
+ "epoch": 0.6384976525821596,
125
+ "grad_norm": 1.32124662399292,
126
+ "learning_rate": 2.3943661971830986e-05,
127
+ "loss": 1.7367,
128
+ "step": 170
129
+ },
130
+ {
131
+ "epoch": 0.676056338028169,
132
+ "grad_norm": 1.232069969177246,
133
+ "learning_rate": 2.535211267605634e-05,
134
+ "loss": 1.7423,
135
+ "step": 180
136
+ },
137
+ {
138
+ "epoch": 0.7136150234741784,
139
+ "grad_norm": 1.9570960998535156,
140
+ "learning_rate": 2.676056338028169e-05,
141
+ "loss": 1.6132,
142
+ "step": 190
143
+ },
144
+ {
145
+ "epoch": 0.7511737089201878,
146
+ "grad_norm": 2.4463119506835938,
147
+ "learning_rate": 2.8169014084507043e-05,
148
+ "loss": 1.6099,
149
+ "step": 200
150
+ },
151
+ {
152
+ "epoch": 0.7887323943661971,
153
+ "grad_norm": 6.601908206939697,
154
+ "learning_rate": 2.9577464788732395e-05,
155
+ "loss": 1.6043,
156
+ "step": 210
157
+ },
158
+ {
159
+ "epoch": 0.8262910798122066,
160
+ "grad_norm": 3.225101947784424,
161
+ "learning_rate": 2.989033942558747e-05,
162
+ "loss": 1.5621,
163
+ "step": 220
164
+ },
165
+ {
166
+ "epoch": 0.863849765258216,
167
+ "grad_norm": 3.698263645172119,
168
+ "learning_rate": 2.9733681462140994e-05,
169
+ "loss": 1.514,
170
+ "step": 230
171
+ },
172
+ {
173
+ "epoch": 0.9014084507042254,
174
+ "grad_norm": 5.209756374359131,
175
+ "learning_rate": 2.9577023498694518e-05,
176
+ "loss": 1.4532,
177
+ "step": 240
178
+ },
179
+ {
180
+ "epoch": 0.9389671361502347,
181
+ "grad_norm": 2.1304848194122314,
182
+ "learning_rate": 2.9420365535248042e-05,
183
+ "loss": 1.4312,
184
+ "step": 250
185
+ },
186
+ {
187
+ "epoch": 0.9765258215962441,
188
+ "grad_norm": 4.837350368499756,
189
+ "learning_rate": 2.926370757180157e-05,
190
+ "loss": 1.3624,
191
+ "step": 260
192
+ },
193
+ {
194
+ "epoch": 1.0,
195
+ "eval_accuracy": 0.6546042953809944,
196
+ "eval_loss": 1.19585382938385,
197
+ "eval_runtime": 4.9178,
198
+ "eval_samples_per_second": 1382.328,
199
+ "eval_steps_per_second": 43.312,
200
+ "step": 267
201
+ },
202
+ {
203
+ "epoch": 1.0112676056338028,
204
+ "grad_norm": 4.779292106628418,
205
+ "learning_rate": 2.9107049608355094e-05,
206
+ "loss": 1.2541,
207
+ "step": 270
208
+ },
209
+ {
210
+ "epoch": 1.0488262910798123,
211
+ "grad_norm": 3.60760498046875,
212
+ "learning_rate": 2.8950391644908618e-05,
213
+ "loss": 1.2271,
214
+ "step": 280
215
+ },
216
+ {
217
+ "epoch": 1.0863849765258216,
218
+ "grad_norm": 2.3788599967956543,
219
+ "learning_rate": 2.8793733681462142e-05,
220
+ "loss": 1.2335,
221
+ "step": 290
222
+ },
223
+ {
224
+ "epoch": 1.123943661971831,
225
+ "grad_norm": 3.353325843811035,
226
+ "learning_rate": 2.8637075718015666e-05,
227
+ "loss": 1.1613,
228
+ "step": 300
229
+ },
230
+ {
231
+ "epoch": 1.1615023474178403,
232
+ "grad_norm": 4.326411247253418,
233
+ "learning_rate": 2.8480417754569193e-05,
234
+ "loss": 1.0754,
235
+ "step": 310
236
+ },
237
+ {
238
+ "epoch": 1.1990610328638498,
239
+ "grad_norm": 3.1939706802368164,
240
+ "learning_rate": 2.8323759791122717e-05,
241
+ "loss": 1.0353,
242
+ "step": 320
243
+ },
244
+ {
245
+ "epoch": 1.236619718309859,
246
+ "grad_norm": 2.8827011585235596,
247
+ "learning_rate": 2.816710182767624e-05,
248
+ "loss": 0.9806,
249
+ "step": 330
250
+ },
251
+ {
252
+ "epoch": 1.2741784037558685,
253
+ "grad_norm": 3.910698652267456,
254
+ "learning_rate": 2.8010443864229766e-05,
255
+ "loss": 1.0813,
256
+ "step": 340
257
+ },
258
+ {
259
+ "epoch": 1.3117370892018778,
260
+ "grad_norm": 3.5916378498077393,
261
+ "learning_rate": 2.7853785900783293e-05,
262
+ "loss": 0.9792,
263
+ "step": 350
264
+ },
265
+ {
266
+ "epoch": 1.3492957746478873,
267
+ "grad_norm": 2.6981167793273926,
268
+ "learning_rate": 2.7697127937336817e-05,
269
+ "loss": 0.9231,
270
+ "step": 360
271
+ },
272
+ {
273
+ "epoch": 1.3868544600938968,
274
+ "grad_norm": 5.702897071838379,
275
+ "learning_rate": 2.754046997389034e-05,
276
+ "loss": 0.9435,
277
+ "step": 370
278
+ },
279
+ {
280
+ "epoch": 1.424413145539906,
281
+ "grad_norm": 4.622363090515137,
282
+ "learning_rate": 2.7383812010443865e-05,
283
+ "loss": 0.8449,
284
+ "step": 380
285
+ },
286
+ {
287
+ "epoch": 1.4619718309859155,
288
+ "grad_norm": 2.2103636264801025,
289
+ "learning_rate": 2.7227154046997393e-05,
290
+ "loss": 0.7713,
291
+ "step": 390
292
+ },
293
+ {
294
+ "epoch": 1.4995305164319248,
295
+ "grad_norm": 4.545182228088379,
296
+ "learning_rate": 2.7070496083550917e-05,
297
+ "loss": 0.7719,
298
+ "step": 400
299
+ },
300
+ {
301
+ "epoch": 1.5370892018779343,
302
+ "grad_norm": 6.883026599884033,
303
+ "learning_rate": 2.691383812010444e-05,
304
+ "loss": 0.7564,
305
+ "step": 410
306
+ },
307
+ {
308
+ "epoch": 1.5746478873239438,
309
+ "grad_norm": 4.770920276641846,
310
+ "learning_rate": 2.6757180156657965e-05,
311
+ "loss": 0.6994,
312
+ "step": 420
313
+ },
314
+ {
315
+ "epoch": 1.612206572769953,
316
+ "grad_norm": 4.413459300994873,
317
+ "learning_rate": 2.660052219321149e-05,
318
+ "loss": 0.6313,
319
+ "step": 430
320
+ },
321
+ {
322
+ "epoch": 1.6497652582159623,
323
+ "grad_norm": 2.0261390209198,
324
+ "learning_rate": 2.6443864229765013e-05,
325
+ "loss": 0.6017,
326
+ "step": 440
327
+ },
328
+ {
329
+ "epoch": 1.6873239436619718,
330
+ "grad_norm": 5.67121696472168,
331
+ "learning_rate": 2.6287206266318537e-05,
332
+ "loss": 0.5792,
333
+ "step": 450
334
+ },
335
+ {
336
+ "epoch": 1.7248826291079813,
337
+ "grad_norm": 2.573594808578491,
338
+ "learning_rate": 2.6146214099216712e-05,
339
+ "loss": 0.545,
340
+ "step": 460
341
+ },
342
+ {
343
+ "epoch": 1.7624413145539906,
344
+ "grad_norm": 4.145854949951172,
345
+ "learning_rate": 2.5989556135770236e-05,
346
+ "loss": 0.4907,
347
+ "step": 470
348
+ },
349
+ {
350
+ "epoch": 1.8,
351
+ "grad_norm": 1.7418975830078125,
352
+ "learning_rate": 2.583289817232376e-05,
353
+ "loss": 0.485,
354
+ "step": 480
355
+ },
356
+ {
357
+ "epoch": 1.8375586854460093,
358
+ "grad_norm": 4.651867866516113,
359
+ "learning_rate": 2.5676240208877287e-05,
360
+ "loss": 0.4572,
361
+ "step": 490
362
+ },
363
+ {
364
+ "epoch": 1.8751173708920188,
365
+ "grad_norm": 4.849829196929932,
366
+ "learning_rate": 2.551958224543081e-05,
367
+ "loss": 0.4864,
368
+ "step": 500
369
+ },
370
+ {
371
+ "epoch": 1.9126760563380283,
372
+ "grad_norm": 2.631229877471924,
373
+ "learning_rate": 2.5362924281984335e-05,
374
+ "loss": 0.4035,
375
+ "step": 510
376
+ },
377
+ {
378
+ "epoch": 1.9502347417840376,
379
+ "grad_norm": 5.099828243255615,
380
+ "learning_rate": 2.520626631853786e-05,
381
+ "loss": 0.3818,
382
+ "step": 520
383
+ },
384
+ {
385
+ "epoch": 1.9877934272300468,
386
+ "grad_norm": 3.25174617767334,
387
+ "learning_rate": 2.5049608355091387e-05,
388
+ "loss": 0.3854,
389
+ "step": 530
390
+ },
391
+ {
392
+ "epoch": 2.0,
393
+ "eval_accuracy": 0.9733745219182113,
394
+ "eval_loss": 0.2675245702266693,
395
+ "eval_runtime": 5.0739,
396
+ "eval_samples_per_second": 1339.801,
397
+ "eval_steps_per_second": 41.98,
398
+ "step": 534
399
+ },
400
+ {
401
+ "epoch": 2.0225352112676056,
402
+ "grad_norm": 4.533545017242432,
403
+ "learning_rate": 2.489295039164491e-05,
404
+ "loss": 0.3589,
405
+ "step": 540
406
+ },
407
+ {
408
+ "epoch": 2.060093896713615,
409
+ "grad_norm": 4.4245991706848145,
410
+ "learning_rate": 2.4736292428198435e-05,
411
+ "loss": 0.378,
412
+ "step": 550
413
+ },
414
+ {
415
+ "epoch": 2.0976525821596246,
416
+ "grad_norm": 5.778880596160889,
417
+ "learning_rate": 2.457963446475196e-05,
418
+ "loss": 0.3653,
419
+ "step": 560
420
+ },
421
+ {
422
+ "epoch": 2.1352112676056336,
423
+ "grad_norm": 3.5573890209198,
424
+ "learning_rate": 2.4422976501305487e-05,
425
+ "loss": 0.3107,
426
+ "step": 570
427
+ },
428
+ {
429
+ "epoch": 2.172769953051643,
430
+ "grad_norm": 3.655824899673462,
431
+ "learning_rate": 2.426631853785901e-05,
432
+ "loss": 0.3405,
433
+ "step": 580
434
+ },
435
+ {
436
+ "epoch": 2.2103286384976526,
437
+ "grad_norm": 2.430022954940796,
438
+ "learning_rate": 2.4109660574412535e-05,
439
+ "loss": 0.3298,
440
+ "step": 590
441
+ },
442
+ {
443
+ "epoch": 2.247887323943662,
444
+ "grad_norm": 2.9207568168640137,
445
+ "learning_rate": 2.3953002610966055e-05,
446
+ "loss": 0.3022,
447
+ "step": 600
448
+ },
449
+ {
450
+ "epoch": 2.2854460093896716,
451
+ "grad_norm": 4.8787007331848145,
452
+ "learning_rate": 2.3796344647519583e-05,
453
+ "loss": 0.3991,
454
+ "step": 610
455
+ },
456
+ {
457
+ "epoch": 2.3230046948356806,
458
+ "grad_norm": 3.0268468856811523,
459
+ "learning_rate": 2.3639686684073107e-05,
460
+ "loss": 0.3159,
461
+ "step": 620
462
+ },
463
+ {
464
+ "epoch": 2.36056338028169,
465
+ "grad_norm": 2.6611557006835938,
466
+ "learning_rate": 2.348302872062663e-05,
467
+ "loss": 0.2868,
468
+ "step": 630
469
+ },
470
+ {
471
+ "epoch": 2.3981220657276996,
472
+ "grad_norm": 2.485551595687866,
473
+ "learning_rate": 2.3326370757180155e-05,
474
+ "loss": 0.3032,
475
+ "step": 640
476
+ },
477
+ {
478
+ "epoch": 2.435680751173709,
479
+ "grad_norm": 4.556153297424316,
480
+ "learning_rate": 2.316971279373368e-05,
481
+ "loss": 0.2985,
482
+ "step": 650
483
+ },
484
+ {
485
+ "epoch": 2.473239436619718,
486
+ "grad_norm": 5.270796298980713,
487
+ "learning_rate": 2.3013054830287207e-05,
488
+ "loss": 0.2839,
489
+ "step": 660
490
+ },
491
+ {
492
+ "epoch": 2.5107981220657276,
493
+ "grad_norm": 3.347005844116211,
494
+ "learning_rate": 2.285639686684073e-05,
495
+ "loss": 0.2871,
496
+ "step": 670
497
+ },
498
+ {
499
+ "epoch": 2.548356807511737,
500
+ "grad_norm": 5.236591815948486,
501
+ "learning_rate": 2.2699738903394255e-05,
502
+ "loss": 0.3028,
503
+ "step": 680
504
+ },
505
+ {
506
+ "epoch": 2.5859154929577466,
507
+ "grad_norm": 2.995059013366699,
508
+ "learning_rate": 2.254308093994778e-05,
509
+ "loss": 0.2537,
510
+ "step": 690
511
+ },
512
+ {
513
+ "epoch": 2.6234741784037556,
514
+ "grad_norm": 2.805640459060669,
515
+ "learning_rate": 2.2386422976501306e-05,
516
+ "loss": 0.297,
517
+ "step": 700
518
+ },
519
+ {
520
+ "epoch": 2.661032863849765,
521
+ "grad_norm": 3.0646071434020996,
522
+ "learning_rate": 2.222976501305483e-05,
523
+ "loss": 0.2453,
524
+ "step": 710
525
+ },
526
+ {
527
+ "epoch": 2.6985915492957746,
528
+ "grad_norm": 3.6719613075256348,
529
+ "learning_rate": 2.2073107049608354e-05,
530
+ "loss": 0.2655,
531
+ "step": 720
532
+ },
533
+ {
534
+ "epoch": 2.736150234741784,
535
+ "grad_norm": 3.2248122692108154,
536
+ "learning_rate": 2.191644908616188e-05,
537
+ "loss": 0.2297,
538
+ "step": 730
539
+ },
540
+ {
541
+ "epoch": 2.7737089201877936,
542
+ "grad_norm": 3.769843578338623,
543
+ "learning_rate": 2.1759791122715406e-05,
544
+ "loss": 0.2548,
545
+ "step": 740
546
+ },
547
+ {
548
+ "epoch": 2.8112676056338026,
549
+ "grad_norm": 3.6679906845092773,
550
+ "learning_rate": 2.160313315926893e-05,
551
+ "loss": 0.2836,
552
+ "step": 750
553
+ },
554
+ {
555
+ "epoch": 2.848826291079812,
556
+ "grad_norm": 1.6924936771392822,
557
+ "learning_rate": 2.1446475195822454e-05,
558
+ "loss": 0.2555,
559
+ "step": 760
560
+ },
561
+ {
562
+ "epoch": 2.8863849765258216,
563
+ "grad_norm": 2.1275901794433594,
564
+ "learning_rate": 2.1289817232375978e-05,
565
+ "loss": 0.2334,
566
+ "step": 770
567
+ },
568
+ {
569
+ "epoch": 2.923943661971831,
570
+ "grad_norm": 6.528135299682617,
571
+ "learning_rate": 2.1133159268929506e-05,
572
+ "loss": 0.2544,
573
+ "step": 780
574
+ },
575
+ {
576
+ "epoch": 2.9615023474178406,
577
+ "grad_norm": 2.4497199058532715,
578
+ "learning_rate": 2.097650130548303e-05,
579
+ "loss": 0.2628,
580
+ "step": 790
581
+ },
582
+ {
583
+ "epoch": 2.9990610328638496,
584
+ "grad_norm": 2.278947591781616,
585
+ "learning_rate": 2.0819843342036554e-05,
586
+ "loss": 0.2473,
587
+ "step": 800
588
+ },
589
+ {
590
+ "epoch": 3.0,
591
+ "eval_accuracy": 0.9767578699617535,
592
+ "eval_loss": 0.1461225152015686,
593
+ "eval_runtime": 5.0057,
594
+ "eval_samples_per_second": 1358.045,
595
+ "eval_steps_per_second": 42.551,
596
+ "step": 801
597
+ },
598
+ {
599
+ "epoch": 3.0338028169014084,
600
+ "grad_norm": 3.1185402870178223,
601
+ "learning_rate": 2.0663185378590078e-05,
602
+ "loss": 0.2245,
603
+ "step": 810
604
+ },
605
+ {
606
+ "epoch": 3.071361502347418,
607
+ "grad_norm": 2.456102132797241,
608
+ "learning_rate": 2.0506527415143602e-05,
609
+ "loss": 0.2423,
610
+ "step": 820
611
+ },
612
+ {
613
+ "epoch": 3.1089201877934274,
614
+ "grad_norm": 2.9463231563568115,
615
+ "learning_rate": 2.034986945169713e-05,
616
+ "loss": 0.2274,
617
+ "step": 830
618
+ },
619
+ {
620
+ "epoch": 3.1464788732394364,
621
+ "grad_norm": 3.5940473079681396,
622
+ "learning_rate": 2.0193211488250653e-05,
623
+ "loss": 0.2368,
624
+ "step": 840
625
+ },
626
+ {
627
+ "epoch": 3.184037558685446,
628
+ "grad_norm": 4.721577167510986,
629
+ "learning_rate": 2.0036553524804177e-05,
630
+ "loss": 0.2554,
631
+ "step": 850
632
+ },
633
+ {
634
+ "epoch": 3.2215962441314554,
635
+ "grad_norm": 2.496495485305786,
636
+ "learning_rate": 1.98798955613577e-05,
637
+ "loss": 0.2363,
638
+ "step": 860
639
+ },
640
+ {
641
+ "epoch": 3.259154929577465,
642
+ "grad_norm": 3.0665740966796875,
643
+ "learning_rate": 1.972323759791123e-05,
644
+ "loss": 0.2248,
645
+ "step": 870
646
+ },
647
+ {
648
+ "epoch": 3.2967136150234744,
649
+ "grad_norm": 4.336172580718994,
650
+ "learning_rate": 1.9566579634464753e-05,
651
+ "loss": 0.1922,
652
+ "step": 880
653
+ },
654
+ {
655
+ "epoch": 3.3342723004694834,
656
+ "grad_norm": 4.110763072967529,
657
+ "learning_rate": 1.9409921671018277e-05,
658
+ "loss": 0.1965,
659
+ "step": 890
660
+ },
661
+ {
662
+ "epoch": 3.371830985915493,
663
+ "grad_norm": 1.9457247257232666,
664
+ "learning_rate": 1.92532637075718e-05,
665
+ "loss": 0.2258,
666
+ "step": 900
667
+ },
668
+ {
669
+ "epoch": 3.4093896713615024,
670
+ "grad_norm": 2.719369411468506,
671
+ "learning_rate": 1.909660574412533e-05,
672
+ "loss": 0.2184,
673
+ "step": 910
674
+ },
675
+ {
676
+ "epoch": 3.446948356807512,
677
+ "grad_norm": 3.438279151916504,
678
+ "learning_rate": 1.8939947780678853e-05,
679
+ "loss": 0.1964,
680
+ "step": 920
681
+ },
682
+ {
683
+ "epoch": 3.4845070422535214,
684
+ "grad_norm": 3.2813045978546143,
685
+ "learning_rate": 1.8783289817232377e-05,
686
+ "loss": 0.2348,
687
+ "step": 930
688
+ },
689
+ {
690
+ "epoch": 3.5220657276995304,
691
+ "grad_norm": 4.151478290557861,
692
+ "learning_rate": 1.86266318537859e-05,
693
+ "loss": 0.2004,
694
+ "step": 940
695
+ },
696
+ {
697
+ "epoch": 3.55962441314554,
698
+ "grad_norm": 3.4271771907806396,
699
+ "learning_rate": 1.8469973890339425e-05,
700
+ "loss": 0.2039,
701
+ "step": 950
702
+ },
703
+ {
704
+ "epoch": 3.5971830985915494,
705
+ "grad_norm": 4.0341901779174805,
706
+ "learning_rate": 1.8313315926892952e-05,
707
+ "loss": 0.1997,
708
+ "step": 960
709
+ },
710
+ {
711
+ "epoch": 3.6347417840375584,
712
+ "grad_norm": 4.762091636657715,
713
+ "learning_rate": 1.8156657963446476e-05,
714
+ "loss": 0.2153,
715
+ "step": 970
716
+ },
717
+ {
718
+ "epoch": 3.672300469483568,
719
+ "grad_norm": 3.3214402198791504,
720
+ "learning_rate": 1.8e-05,
721
+ "loss": 0.1801,
722
+ "step": 980
723
+ },
724
+ {
725
+ "epoch": 3.7098591549295774,
726
+ "grad_norm": 3.84503173828125,
727
+ "learning_rate": 1.7843342036553525e-05,
728
+ "loss": 0.2106,
729
+ "step": 990
730
+ },
731
+ {
732
+ "epoch": 3.747417840375587,
733
+ "grad_norm": 3.303781747817993,
734
+ "learning_rate": 1.7686684073107052e-05,
735
+ "loss": 0.1965,
736
+ "step": 1000
737
+ },
738
+ {
739
+ "epoch": 3.7849765258215964,
740
+ "grad_norm": 2.691159248352051,
741
+ "learning_rate": 1.7530026109660576e-05,
742
+ "loss": 0.193,
743
+ "step": 1010
744
+ },
745
+ {
746
+ "epoch": 3.8225352112676054,
747
+ "grad_norm": 4.134768009185791,
748
+ "learning_rate": 1.73733681462141e-05,
749
+ "loss": 0.1908,
750
+ "step": 1020
751
+ },
752
+ {
753
+ "epoch": 3.860093896713615,
754
+ "grad_norm": 2.9195241928100586,
755
+ "learning_rate": 1.7216710182767624e-05,
756
+ "loss": 0.1886,
757
+ "step": 1030
758
+ },
759
+ {
760
+ "epoch": 3.8976525821596244,
761
+ "grad_norm": 3.795133352279663,
762
+ "learning_rate": 1.706005221932115e-05,
763
+ "loss": 0.2007,
764
+ "step": 1040
765
+ },
766
+ {
767
+ "epoch": 3.935211267605634,
768
+ "grad_norm": 3.9436607360839844,
769
+ "learning_rate": 1.6903394255874676e-05,
770
+ "loss": 0.1834,
771
+ "step": 1050
772
+ },
773
+ {
774
+ "epoch": 3.9727699530516434,
775
+ "grad_norm": 3.4115564823150635,
776
+ "learning_rate": 1.67467362924282e-05,
777
+ "loss": 0.1997,
778
+ "step": 1060
779
+ },
780
+ {
781
+ "epoch": 4.0,
782
+ "eval_accuracy": 0.980435422182995,
783
+ "eval_loss": 0.10877315700054169,
784
+ "eval_runtime": 4.9191,
785
+ "eval_samples_per_second": 1381.955,
786
+ "eval_steps_per_second": 43.3,
787
+ "step": 1068
788
+ },
789
+ {
790
+ "epoch": 4.007511737089202,
791
+ "grad_norm": 5.121041774749756,
792
+ "learning_rate": 1.6590078328981724e-05,
793
+ "loss": 0.1785,
794
+ "step": 1070
795
+ },
796
+ {
797
+ "epoch": 4.045070422535211,
798
+ "grad_norm": 2.908527374267578,
799
+ "learning_rate": 1.643342036553525e-05,
800
+ "loss": 0.1678,
801
+ "step": 1080
802
+ },
803
+ {
804
+ "epoch": 4.08262910798122,
805
+ "grad_norm": 1.9687402248382568,
806
+ "learning_rate": 1.6276762402088775e-05,
807
+ "loss": 0.192,
808
+ "step": 1090
809
+ },
810
+ {
811
+ "epoch": 4.12018779342723,
812
+ "grad_norm": 2.722937822341919,
813
+ "learning_rate": 1.61201044386423e-05,
814
+ "loss": 0.1983,
815
+ "step": 1100
816
+ },
817
+ {
818
+ "epoch": 4.157746478873239,
819
+ "grad_norm": 2.3741490840911865,
820
+ "learning_rate": 1.5963446475195823e-05,
821
+ "loss": 0.2145,
822
+ "step": 1110
823
+ },
824
+ {
825
+ "epoch": 4.195305164319249,
826
+ "grad_norm": 2.653414011001587,
827
+ "learning_rate": 1.5806788511749348e-05,
828
+ "loss": 0.1701,
829
+ "step": 1120
830
+ },
831
+ {
832
+ "epoch": 4.232863849765258,
833
+ "grad_norm": 3.444087266921997,
834
+ "learning_rate": 1.5650130548302875e-05,
835
+ "loss": 0.2047,
836
+ "step": 1130
837
+ },
838
+ {
839
+ "epoch": 4.270422535211267,
840
+ "grad_norm": 2.024235486984253,
841
+ "learning_rate": 1.54934725848564e-05,
842
+ "loss": 0.1817,
843
+ "step": 1140
844
+ },
845
+ {
846
+ "epoch": 4.307981220657277,
847
+ "grad_norm": 2.742171049118042,
848
+ "learning_rate": 1.533681462140992e-05,
849
+ "loss": 0.1723,
850
+ "step": 1150
851
+ },
852
+ {
853
+ "epoch": 4.345539906103286,
854
+ "grad_norm": 3.3700480461120605,
855
+ "learning_rate": 1.5180156657963446e-05,
856
+ "loss": 0.17,
857
+ "step": 1160
858
+ },
859
+ {
860
+ "epoch": 4.383098591549296,
861
+ "grad_norm": 2.552915573120117,
862
+ "learning_rate": 1.5023498694516973e-05,
863
+ "loss": 0.1802,
864
+ "step": 1170
865
+ },
866
+ {
867
+ "epoch": 4.420657276995305,
868
+ "grad_norm": 3.3317511081695557,
869
+ "learning_rate": 1.4866840731070497e-05,
870
+ "loss": 0.1933,
871
+ "step": 1180
872
+ },
873
+ {
874
+ "epoch": 4.458215962441314,
875
+ "grad_norm": 1.9266548156738281,
876
+ "learning_rate": 1.4710182767624021e-05,
877
+ "loss": 0.1739,
878
+ "step": 1190
879
+ },
880
+ {
881
+ "epoch": 4.495774647887324,
882
+ "grad_norm": 2.1459243297576904,
883
+ "learning_rate": 1.4553524804177547e-05,
884
+ "loss": 0.1599,
885
+ "step": 1200
886
+ },
887
+ {
888
+ "epoch": 4.533333333333333,
889
+ "grad_norm": 3.9314770698547363,
890
+ "learning_rate": 1.4396866840731071e-05,
891
+ "loss": 0.1958,
892
+ "step": 1210
893
+ },
894
+ {
895
+ "epoch": 4.570892018779343,
896
+ "grad_norm": 2.6377363204956055,
897
+ "learning_rate": 1.4240208877284597e-05,
898
+ "loss": 0.1604,
899
+ "step": 1220
900
+ },
901
+ {
902
+ "epoch": 4.608450704225352,
903
+ "grad_norm": 2.810866594314575,
904
+ "learning_rate": 1.408355091383812e-05,
905
+ "loss": 0.1495,
906
+ "step": 1230
907
+ },
908
+ {
909
+ "epoch": 4.646009389671361,
910
+ "grad_norm": 2.2084455490112305,
911
+ "learning_rate": 1.3926892950391646e-05,
912
+ "loss": 0.185,
913
+ "step": 1240
914
+ },
915
+ {
916
+ "epoch": 4.683568075117371,
917
+ "grad_norm": 2.7217283248901367,
918
+ "learning_rate": 1.377023498694517e-05,
919
+ "loss": 0.1757,
920
+ "step": 1250
921
+ },
922
+ {
923
+ "epoch": 4.72112676056338,
924
+ "grad_norm": 3.075267791748047,
925
+ "learning_rate": 1.3613577023498696e-05,
926
+ "loss": 0.1814,
927
+ "step": 1260
928
+ },
929
+ {
930
+ "epoch": 4.758685446009389,
931
+ "grad_norm": 3.2452406883239746,
932
+ "learning_rate": 1.345691906005222e-05,
933
+ "loss": 0.1622,
934
+ "step": 1270
935
+ },
936
+ {
937
+ "epoch": 4.796244131455399,
938
+ "grad_norm": 2.712754487991333,
939
+ "learning_rate": 1.3300261096605744e-05,
940
+ "loss": 0.1714,
941
+ "step": 1280
942
+ },
943
+ {
944
+ "epoch": 4.833802816901408,
945
+ "grad_norm": 1.6795600652694702,
946
+ "learning_rate": 1.3143603133159269e-05,
947
+ "loss": 0.1519,
948
+ "step": 1290
949
+ },
950
+ {
951
+ "epoch": 4.871361502347418,
952
+ "grad_norm": 3.9085493087768555,
953
+ "learning_rate": 1.2986945169712793e-05,
954
+ "loss": 0.1758,
955
+ "step": 1300
956
+ },
957
+ {
958
+ "epoch": 4.908920187793427,
959
+ "grad_norm": 3.529478073120117,
960
+ "learning_rate": 1.2830287206266318e-05,
961
+ "loss": 0.1549,
962
+ "step": 1310
963
+ },
964
+ {
965
+ "epoch": 4.946478873239436,
966
+ "grad_norm": 2.559157609939575,
967
+ "learning_rate": 1.2673629242819842e-05,
968
+ "loss": 0.1824,
969
+ "step": 1320
970
+ },
971
+ {
972
+ "epoch": 4.984037558685446,
973
+ "grad_norm": 2.2350497245788574,
974
+ "learning_rate": 1.2516971279373368e-05,
975
+ "loss": 0.1723,
976
+ "step": 1330
977
+ },
978
+ {
979
+ "epoch": 5.0,
980
+ "eval_accuracy": 0.9826419535157399,
981
+ "eval_loss": 0.09542840719223022,
982
+ "eval_runtime": 5.0389,
983
+ "eval_samples_per_second": 1349.105,
984
+ "eval_steps_per_second": 42.271,
985
+ "step": 1335
986
+ },
987
+ {
988
+ "epoch": 5.018779342723005,
989
+ "grad_norm": 2.5073907375335693,
990
+ "learning_rate": 1.2360313315926892e-05,
991
+ "loss": 0.1401,
992
+ "step": 1340
993
+ },
994
+ {
995
+ "epoch": 5.056338028169014,
996
+ "grad_norm": 4.696757793426514,
997
+ "learning_rate": 1.2203655352480418e-05,
998
+ "loss": 0.1801,
999
+ "step": 1350
1000
+ },
1001
+ {
1002
+ "epoch": 5.093896713615023,
1003
+ "grad_norm": 1.2180489301681519,
1004
+ "learning_rate": 1.2046997389033942e-05,
1005
+ "loss": 0.1335,
1006
+ "step": 1360
1007
+ },
1008
+ {
1009
+ "epoch": 5.131455399061033,
1010
+ "grad_norm": 0.887860119342804,
1011
+ "learning_rate": 1.1890339425587468e-05,
1012
+ "loss": 0.1479,
1013
+ "step": 1370
1014
+ },
1015
+ {
1016
+ "epoch": 5.169014084507042,
1017
+ "grad_norm": 3.6347432136535645,
1018
+ "learning_rate": 1.1733681462140992e-05,
1019
+ "loss": 0.1575,
1020
+ "step": 1380
1021
+ },
1022
+ {
1023
+ "epoch": 5.206572769953052,
1024
+ "grad_norm": 2.901700496673584,
1025
+ "learning_rate": 1.1577023498694518e-05,
1026
+ "loss": 0.1367,
1027
+ "step": 1390
1028
+ },
1029
+ {
1030
+ "epoch": 5.244131455399061,
1031
+ "grad_norm": 2.6395390033721924,
1032
+ "learning_rate": 1.1420365535248042e-05,
1033
+ "loss": 0.144,
1034
+ "step": 1400
1035
+ },
1036
+ {
1037
+ "epoch": 5.28169014084507,
1038
+ "grad_norm": 3.923652172088623,
1039
+ "learning_rate": 1.1263707571801567e-05,
1040
+ "loss": 0.1576,
1041
+ "step": 1410
1042
+ },
1043
+ {
1044
+ "epoch": 5.31924882629108,
1045
+ "grad_norm": 2.290224313735962,
1046
+ "learning_rate": 1.1107049608355092e-05,
1047
+ "loss": 0.16,
1048
+ "step": 1420
1049
+ },
1050
+ {
1051
+ "epoch": 5.356807511737089,
1052
+ "grad_norm": 2.332317590713501,
1053
+ "learning_rate": 1.0950391644908617e-05,
1054
+ "loss": 0.1505,
1055
+ "step": 1430
1056
+ },
1057
+ {
1058
+ "epoch": 5.394366197183099,
1059
+ "grad_norm": 3.474155902862549,
1060
+ "learning_rate": 1.0793733681462141e-05,
1061
+ "loss": 0.1828,
1062
+ "step": 1440
1063
+ },
1064
+ {
1065
+ "epoch": 5.431924882629108,
1066
+ "grad_norm": 2.5219180583953857,
1067
+ "learning_rate": 1.0637075718015665e-05,
1068
+ "loss": 0.1563,
1069
+ "step": 1450
1070
+ },
1071
+ {
1072
+ "epoch": 5.469483568075117,
1073
+ "grad_norm": 4.863851547241211,
1074
+ "learning_rate": 1.0480417754569191e-05,
1075
+ "loss": 0.1308,
1076
+ "step": 1460
1077
+ },
1078
+ {
1079
+ "epoch": 5.507042253521127,
1080
+ "grad_norm": 4.817688941955566,
1081
+ "learning_rate": 1.0323759791122715e-05,
1082
+ "loss": 0.1757,
1083
+ "step": 1470
1084
+ },
1085
+ {
1086
+ "epoch": 5.544600938967136,
1087
+ "grad_norm": 3.194732189178467,
1088
+ "learning_rate": 1.0167101827676241e-05,
1089
+ "loss": 0.1577,
1090
+ "step": 1480
1091
+ },
1092
+ {
1093
+ "epoch": 5.582159624413146,
1094
+ "grad_norm": 3.6605474948883057,
1095
+ "learning_rate": 1.0010443864229765e-05,
1096
+ "loss": 0.2044,
1097
+ "step": 1490
1098
+ },
1099
+ {
1100
+ "epoch": 5.619718309859155,
1101
+ "grad_norm": 2.427701473236084,
1102
+ "learning_rate": 9.853785900783291e-06,
1103
+ "loss": 0.1574,
1104
+ "step": 1500
1105
+ },
1106
+ {
1107
+ "epoch": 5.657276995305164,
1108
+ "grad_norm": 2.8025519847869873,
1109
+ "learning_rate": 9.697127937336815e-06,
1110
+ "loss": 0.188,
1111
+ "step": 1510
1112
+ },
1113
+ {
1114
+ "epoch": 5.694835680751174,
1115
+ "grad_norm": 2.042407989501953,
1116
+ "learning_rate": 9.54046997389034e-06,
1117
+ "loss": 0.1639,
1118
+ "step": 1520
1119
+ },
1120
+ {
1121
+ "epoch": 5.732394366197183,
1122
+ "grad_norm": 4.5383477210998535,
1123
+ "learning_rate": 9.383812010443865e-06,
1124
+ "loss": 0.1641,
1125
+ "step": 1530
1126
+ },
1127
+ {
1128
+ "epoch": 5.769953051643192,
1129
+ "grad_norm": 2.919588804244995,
1130
+ "learning_rate": 9.22715404699739e-06,
1131
+ "loss": 0.1374,
1132
+ "step": 1540
1133
+ },
1134
+ {
1135
+ "epoch": 5.807511737089202,
1136
+ "grad_norm": 2.4344029426574707,
1137
+ "learning_rate": 9.070496083550915e-06,
1138
+ "loss": 0.1711,
1139
+ "step": 1550
1140
+ },
1141
+ {
1142
+ "epoch": 5.845070422535211,
1143
+ "grad_norm": 1.5614906549453735,
1144
+ "learning_rate": 8.913838120104439e-06,
1145
+ "loss": 0.1624,
1146
+ "step": 1560
1147
+ },
1148
+ {
1149
+ "epoch": 5.882629107981221,
1150
+ "grad_norm": 3.0189967155456543,
1151
+ "learning_rate": 8.757180156657963e-06,
1152
+ "loss": 0.1691,
1153
+ "step": 1570
1154
+ },
1155
+ {
1156
+ "epoch": 5.92018779342723,
1157
+ "grad_norm": 2.44000506401062,
1158
+ "learning_rate": 8.600522193211488e-06,
1159
+ "loss": 0.1513,
1160
+ "step": 1580
1161
+ },
1162
+ {
1163
+ "epoch": 5.957746478873239,
1164
+ "grad_norm": 2.4327423572540283,
1165
+ "learning_rate": 8.443864229765013e-06,
1166
+ "loss": 0.1538,
1167
+ "step": 1590
1168
+ },
1169
+ {
1170
+ "epoch": 5.995305164319249,
1171
+ "grad_norm": 2.1192240715026855,
1172
+ "learning_rate": 8.287206266318538e-06,
1173
+ "loss": 0.1442,
1174
+ "step": 1600
1175
+ },
1176
+ {
1177
+ "epoch": 6.0,
1178
+ "eval_accuracy": 0.981318034716093,
1179
+ "eval_loss": 0.09270217269659042,
1180
+ "eval_runtime": 4.8524,
1181
+ "eval_samples_per_second": 1400.961,
1182
+ "eval_steps_per_second": 43.896,
1183
+ "step": 1602
1184
+ },
1185
+ {
1186
+ "epoch": 6.030046948356808,
1187
+ "grad_norm": 1.8678548336029053,
1188
+ "learning_rate": 8.130548302872062e-06,
1189
+ "loss": 0.1328,
1190
+ "step": 1610
1191
+ },
1192
+ {
1193
+ "epoch": 6.067605633802817,
1194
+ "grad_norm": 3.0712783336639404,
1195
+ "learning_rate": 7.973890339425586e-06,
1196
+ "loss": 0.1543,
1197
+ "step": 1620
1198
+ },
1199
+ {
1200
+ "epoch": 6.105164319248826,
1201
+ "grad_norm": 4.49588680267334,
1202
+ "learning_rate": 7.817232375979112e-06,
1203
+ "loss": 0.1452,
1204
+ "step": 1630
1205
+ },
1206
+ {
1207
+ "epoch": 6.142723004694836,
1208
+ "grad_norm": 3.9594759941101074,
1209
+ "learning_rate": 7.660574412532636e-06,
1210
+ "loss": 0.1513,
1211
+ "step": 1640
1212
+ },
1213
+ {
1214
+ "epoch": 6.180281690140845,
1215
+ "grad_norm": 2.528153657913208,
1216
+ "learning_rate": 7.503916449086162e-06,
1217
+ "loss": 0.1589,
1218
+ "step": 1650
1219
+ },
1220
+ {
1221
+ "epoch": 6.217840375586855,
1222
+ "grad_norm": 2.159458875656128,
1223
+ "learning_rate": 7.347258485639687e-06,
1224
+ "loss": 0.1443,
1225
+ "step": 1660
1226
+ },
1227
+ {
1228
+ "epoch": 6.255399061032864,
1229
+ "grad_norm": 2.098022222518921,
1230
+ "learning_rate": 7.190600522193212e-06,
1231
+ "loss": 0.1564,
1232
+ "step": 1670
1233
+ },
1234
+ {
1235
+ "epoch": 6.292957746478873,
1236
+ "grad_norm": 1.993698239326477,
1237
+ "learning_rate": 7.033942558746737e-06,
1238
+ "loss": 0.1401,
1239
+ "step": 1680
1240
+ },
1241
+ {
1242
+ "epoch": 6.330516431924883,
1243
+ "grad_norm": 2.2639145851135254,
1244
+ "learning_rate": 6.877284595300262e-06,
1245
+ "loss": 0.1452,
1246
+ "step": 1690
1247
+ },
1248
+ {
1249
+ "epoch": 6.368075117370892,
1250
+ "grad_norm": 2.5003936290740967,
1251
+ "learning_rate": 6.720626631853786e-06,
1252
+ "loss": 0.1439,
1253
+ "step": 1700
1254
+ },
1255
+ {
1256
+ "epoch": 6.405633802816902,
1257
+ "grad_norm": 2.0841052532196045,
1258
+ "learning_rate": 6.563968668407311e-06,
1259
+ "loss": 0.1438,
1260
+ "step": 1710
1261
+ },
1262
+ {
1263
+ "epoch": 6.443192488262911,
1264
+ "grad_norm": 3.550182819366455,
1265
+ "learning_rate": 6.4073107049608355e-06,
1266
+ "loss": 0.1433,
1267
+ "step": 1720
1268
+ },
1269
+ {
1270
+ "epoch": 6.48075117370892,
1271
+ "grad_norm": 1.4857251644134521,
1272
+ "learning_rate": 6.2506527415143605e-06,
1273
+ "loss": 0.1404,
1274
+ "step": 1730
1275
+ },
1276
+ {
1277
+ "epoch": 6.51830985915493,
1278
+ "grad_norm": 3.503309726715088,
1279
+ "learning_rate": 6.093994778067885e-06,
1280
+ "loss": 0.1493,
1281
+ "step": 1740
1282
+ },
1283
+ {
1284
+ "epoch": 6.555868544600939,
1285
+ "grad_norm": 3.59545636177063,
1286
+ "learning_rate": 5.93733681462141e-06,
1287
+ "loss": 0.1563,
1288
+ "step": 1750
1289
+ },
1290
+ {
1291
+ "epoch": 6.593427230046949,
1292
+ "grad_norm": 2.879582405090332,
1293
+ "learning_rate": 5.780678851174934e-06,
1294
+ "loss": 0.122,
1295
+ "step": 1760
1296
+ },
1297
+ {
1298
+ "epoch": 6.630985915492958,
1299
+ "grad_norm": 1.7240543365478516,
1300
+ "learning_rate": 5.624020887728459e-06,
1301
+ "loss": 0.1404,
1302
+ "step": 1770
1303
+ },
1304
+ {
1305
+ "epoch": 6.668544600938967,
1306
+ "grad_norm": 3.0438528060913086,
1307
+ "learning_rate": 5.467362924281984e-06,
1308
+ "loss": 0.1432,
1309
+ "step": 1780
1310
+ },
1311
+ {
1312
+ "epoch": 6.706103286384977,
1313
+ "grad_norm": 2.496366024017334,
1314
+ "learning_rate": 5.310704960835509e-06,
1315
+ "loss": 0.1277,
1316
+ "step": 1790
1317
+ },
1318
+ {
1319
+ "epoch": 6.743661971830986,
1320
+ "grad_norm": 1.7166277170181274,
1321
+ "learning_rate": 5.154046997389034e-06,
1322
+ "loss": 0.143,
1323
+ "step": 1800
1324
+ },
1325
+ {
1326
+ "epoch": 6.781220657276995,
1327
+ "grad_norm": 2.4547784328460693,
1328
+ "learning_rate": 4.997389033942559e-06,
1329
+ "loss": 0.1198,
1330
+ "step": 1810
1331
+ },
1332
+ {
1333
+ "epoch": 6.818779342723005,
1334
+ "grad_norm": 2.604220390319824,
1335
+ "learning_rate": 4.840731070496084e-06,
1336
+ "loss": 0.1705,
1337
+ "step": 1820
1338
+ },
1339
+ {
1340
+ "epoch": 6.856338028169014,
1341
+ "grad_norm": 2.7237601280212402,
1342
+ "learning_rate": 4.684073107049609e-06,
1343
+ "loss": 0.1506,
1344
+ "step": 1830
1345
+ },
1346
+ {
1347
+ "epoch": 6.893896713615024,
1348
+ "grad_norm": 2.638058662414551,
1349
+ "learning_rate": 4.527415143603134e-06,
1350
+ "loss": 0.154,
1351
+ "step": 1840
1352
+ },
1353
+ {
1354
+ "epoch": 6.931455399061033,
1355
+ "grad_norm": 3.8382205963134766,
1356
+ "learning_rate": 4.3707571801566586e-06,
1357
+ "loss": 0.1553,
1358
+ "step": 1850
1359
+ },
1360
+ {
1361
+ "epoch": 6.969014084507043,
1362
+ "grad_norm": 2.071164131164551,
1363
+ "learning_rate": 4.2140992167101835e-06,
1364
+ "loss": 0.1397,
1365
+ "step": 1860
1366
+ },
1367
+ {
1368
+ "epoch": 7.0,
1369
+ "eval_accuracy": 0.9811709326272433,
1370
+ "eval_loss": 0.08920056372880936,
1371
+ "eval_runtime": 4.9166,
1372
+ "eval_samples_per_second": 1382.662,
1373
+ "eval_steps_per_second": 43.323,
1374
+ "step": 1869
1375
+ },
1376
+ {
1377
+ "epoch": 7.003755868544601,
1378
+ "grad_norm": 2.5346381664276123,
1379
+ "learning_rate": 4.0574412532637075e-06,
1380
+ "loss": 0.1296,
1381
+ "step": 1870
1382
+ },
1383
+ {
1384
+ "epoch": 7.041314553990611,
1385
+ "grad_norm": 2.575307846069336,
1386
+ "learning_rate": 3.9007832898172325e-06,
1387
+ "loss": 0.1389,
1388
+ "step": 1880
1389
+ },
1390
+ {
1391
+ "epoch": 7.07887323943662,
1392
+ "grad_norm": 2.0408527851104736,
1393
+ "learning_rate": 3.7441253263707574e-06,
1394
+ "loss": 0.1521,
1395
+ "step": 1890
1396
+ },
1397
+ {
1398
+ "epoch": 7.1164319248826295,
1399
+ "grad_norm": 3.2742061614990234,
1400
+ "learning_rate": 3.5874673629242823e-06,
1401
+ "loss": 0.1342,
1402
+ "step": 1900
1403
+ },
1404
+ {
1405
+ "epoch": 7.153990610328639,
1406
+ "grad_norm": 1.4502960443496704,
1407
+ "learning_rate": 3.4308093994778068e-06,
1408
+ "loss": 0.1204,
1409
+ "step": 1910
1410
+ },
1411
+ {
1412
+ "epoch": 7.191549295774648,
1413
+ "grad_norm": 3.7600743770599365,
1414
+ "learning_rate": 3.2741514360313317e-06,
1415
+ "loss": 0.1431,
1416
+ "step": 1920
1417
+ },
1418
+ {
1419
+ "epoch": 7.229107981220658,
1420
+ "grad_norm": 2.7332417964935303,
1421
+ "learning_rate": 3.1174934725848566e-06,
1422
+ "loss": 0.1281,
1423
+ "step": 1930
1424
+ },
1425
+ {
1426
+ "epoch": 7.266666666666667,
1427
+ "grad_norm": 2.6618921756744385,
1428
+ "learning_rate": 2.960835509138381e-06,
1429
+ "loss": 0.141,
1430
+ "step": 1940
1431
+ },
1432
+ {
1433
+ "epoch": 7.304225352112676,
1434
+ "grad_norm": 3.625688314437866,
1435
+ "learning_rate": 2.804177545691906e-06,
1436
+ "loss": 0.1455,
1437
+ "step": 1950
1438
+ },
1439
+ {
1440
+ "epoch": 7.341784037558686,
1441
+ "grad_norm": 2.0667765140533447,
1442
+ "learning_rate": 2.647519582245431e-06,
1443
+ "loss": 0.1359,
1444
+ "step": 1960
1445
+ },
1446
+ {
1447
+ "epoch": 7.379342723004695,
1448
+ "grad_norm": 2.369652509689331,
1449
+ "learning_rate": 2.490861618798956e-06,
1450
+ "loss": 0.1295,
1451
+ "step": 1970
1452
+ },
1453
+ {
1454
+ "epoch": 7.416901408450705,
1455
+ "grad_norm": 3.836838722229004,
1456
+ "learning_rate": 2.3342036553524807e-06,
1457
+ "loss": 0.1489,
1458
+ "step": 1980
1459
+ },
1460
+ {
1461
+ "epoch": 7.454460093896714,
1462
+ "grad_norm": 3.3261311054229736,
1463
+ "learning_rate": 2.1775456919060052e-06,
1464
+ "loss": 0.1289,
1465
+ "step": 1990
1466
+ },
1467
+ {
1468
+ "epoch": 7.492018779342723,
1469
+ "grad_norm": 2.6514954566955566,
1470
+ "learning_rate": 2.0208877284595297e-06,
1471
+ "loss": 0.1185,
1472
+ "step": 2000
1473
+ },
1474
+ {
1475
+ "epoch": 7.529577464788733,
1476
+ "grad_norm": 2.1017005443573,
1477
+ "learning_rate": 1.8642297650130548e-06,
1478
+ "loss": 0.1472,
1479
+ "step": 2010
1480
+ },
1481
+ {
1482
+ "epoch": 7.567136150234742,
1483
+ "grad_norm": 2.5104258060455322,
1484
+ "learning_rate": 1.7075718015665795e-06,
1485
+ "loss": 0.1467,
1486
+ "step": 2020
1487
+ },
1488
+ {
1489
+ "epoch": 7.6046948356807516,
1490
+ "grad_norm": 1.7915935516357422,
1491
+ "learning_rate": 1.5509138381201045e-06,
1492
+ "loss": 0.1212,
1493
+ "step": 2030
1494
+ },
1495
+ {
1496
+ "epoch": 7.642253521126761,
1497
+ "grad_norm": 2.4937989711761475,
1498
+ "learning_rate": 1.3942558746736294e-06,
1499
+ "loss": 0.1395,
1500
+ "step": 2040
1501
+ },
1502
+ {
1503
+ "epoch": 7.67981220657277,
1504
+ "grad_norm": 2.758594274520874,
1505
+ "learning_rate": 1.237597911227154e-06,
1506
+ "loss": 0.1361,
1507
+ "step": 2050
1508
+ },
1509
+ {
1510
+ "epoch": 7.71737089201878,
1511
+ "grad_norm": 2.291672468185425,
1512
+ "learning_rate": 1.0809399477806788e-06,
1513
+ "loss": 0.1182,
1514
+ "step": 2060
1515
+ },
1516
+ {
1517
+ "epoch": 7.754929577464789,
1518
+ "grad_norm": 1.944736361503601,
1519
+ "learning_rate": 9.242819843342037e-07,
1520
+ "loss": 0.1307,
1521
+ "step": 2070
1522
+ },
1523
+ {
1524
+ "epoch": 7.792488262910798,
1525
+ "grad_norm": 1.448411226272583,
1526
+ "learning_rate": 7.676240208877285e-07,
1527
+ "loss": 0.1407,
1528
+ "step": 2080
1529
+ },
1530
+ {
1531
+ "epoch": 7.830046948356808,
1532
+ "grad_norm": 3.276000499725342,
1533
+ "learning_rate": 6.109660574412533e-07,
1534
+ "loss": 0.1361,
1535
+ "step": 2090
1536
+ },
1537
+ {
1538
+ "epoch": 7.867605633802817,
1539
+ "grad_norm": 3.627788543701172,
1540
+ "learning_rate": 4.5430809399477806e-07,
1541
+ "loss": 0.131,
1542
+ "step": 2100
1543
+ },
1544
+ {
1545
+ "epoch": 7.905164319248827,
1546
+ "grad_norm": 1.2533661127090454,
1547
+ "learning_rate": 2.9765013054830287e-07,
1548
+ "loss": 0.1245,
1549
+ "step": 2110
1550
+ },
1551
+ {
1552
+ "epoch": 7.942723004694836,
1553
+ "grad_norm": 1.472484827041626,
1554
+ "learning_rate": 1.409921671018277e-07,
1555
+ "loss": 0.1368,
1556
+ "step": 2120
1557
+ },
1558
+ {
1559
+ "epoch": 7.972769953051643,
1560
+ "eval_accuracy": 0.9811709326272433,
1561
+ "eval_loss": 0.08957477658987045,
1562
+ "eval_runtime": 5.418,
1563
+ "eval_samples_per_second": 1254.697,
1564
+ "eval_steps_per_second": 39.313,
1565
+ "step": 2128
1566
+ },
1567
+ {
1568
+ "epoch": 7.972769953051643,
1569
+ "step": 2128,
1570
+ "total_flos": 3.767900833756416e+18,
1571
+ "train_loss": 0.5178930132572812,
1572
+ "train_runtime": 756.2923,
1573
+ "train_samples_per_second": 540.468,
1574
+ "train_steps_per_second": 2.814
1575
+ }
1576
+ ],
1577
+ "logging_steps": 10,
1578
+ "max_steps": 2128,
1579
+ "num_input_tokens_seen": 0,
1580
+ "num_train_epochs": 8,
1581
+ "save_steps": 500,
1582
+ "stateful_callbacks": {
1583
+ "TrainerControl": {
1584
+ "args": {
1585
+ "should_epoch_stop": false,
1586
+ "should_evaluate": false,
1587
+ "should_log": false,
1588
+ "should_save": true,
1589
+ "should_training_stop": true
1590
+ },
1591
+ "attributes": {}
1592
+ }
1593
+ },
1594
+ "total_flos": 3.767900833756416e+18,
1595
+ "train_batch_size": 48,
1596
+ "trial_name": null,
1597
+ "trial_params": null
1598
+ }
hpo-examples/audio-classification/ac/training_args.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:11e71a8faf8833e9bfb138217263fb0518314e3b8597902b752b2bc9dd143942
3
+ size 5368
hpo-examples/audio-classification/requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ datasets>=1.14.0
2
+ evaluate
3
+ librosa
4
+ torchaudio
5
+ torch>=1.6
hpo-examples/audio-classification/run.sh ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ CUDA_VISIBLE_DEVICES=0 python run_audio_classification.py \
2
+ --model_name_or_path facebook/wav2vec2-base \
3
+ --dataset_name superb \
4
+ --dataset_config_name ks \
5
+ --trust_remote_code \
6
+ --output_dir wav2vec2-base-ft-keyword-spotting \
7
+ --overwrite_output_dir \
8
+ --remove_unused_columns False \
9
+ --do_train \
10
+ --do_eval \
11
+ --fp16 \
12
+ --learning_rate 3e-5 \
13
+ --max_length_seconds 1 \
14
+ --attention_mask False \
15
+ --warmup_ratio 0.1 \
16
+ --num_train_epochs 8 \
17
+ --per_device_train_batch_size 64 \
18
+ --gradient_accumulation_steps 4 \
19
+ --per_device_eval_batch_size 32 \
20
+ --dataloader_num_workers 4 \
21
+ --logging_strategy steps \
22
+ --logging_steps 10 \
23
+ --eval_strategy epoch \
24
+ --save_strategy epoch \
25
+ --load_best_model_at_end True \
26
+ --metric_for_best_model accuracy \
27
+ --save_total_limit 3 \
28
+ --seed 0 \
29
+ --push_to_hub \
30
+ --apply-trp --trp-depths 1 --trp-p 0.1 --trp-lambdas 0.4 0.2 0.1
hpo-examples/audio-classification/run_audio_classification.py ADDED
@@ -0,0 +1,462 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ # Copyright 2021 The HuggingFace Inc. team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ import logging
18
+ import os
19
+ import sys
20
+ import warnings
21
+ from dataclasses import dataclass, field
22
+ from random import randint
23
+ from typing import Optional, List
24
+
25
+ import datasets
26
+ import evaluate
27
+ import numpy as np
28
+ from datasets import DatasetDict, load_dataset
29
+
30
+ import transformers
31
+ from transformers import (
32
+ AutoConfig,
33
+ AutoFeatureExtractor,
34
+ AutoModelForAudioClassification,
35
+ HfArgumentParser,
36
+ Trainer,
37
+ TrainingArguments,
38
+ set_seed,
39
+ )
40
+ from transformers.trainer_utils import get_last_checkpoint
41
+ from transformers.utils import check_min_version, send_example_telemetry
42
+ from transformers.utils.versions import require_version
43
+
44
+
45
+ from trplib import apply_trp
46
+
47
+
48
+ logger = logging.getLogger(__name__)
49
+
50
+ # # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
51
+ # check_min_version("4.50.0.dev0")
52
+
53
+ require_version("datasets>=1.14.0", "To fix: pip install -r examples/pytorch/audio-classification/requirements.txt")
54
+
55
+
56
+ def random_subsample(wav: np.ndarray, max_length: float, sample_rate: int = 16000):
57
+ """Randomly sample chunks of `max_length` seconds from the input audio"""
58
+ sample_length = int(round(sample_rate * max_length))
59
+ if len(wav) <= sample_length:
60
+ return wav
61
+ random_offset = randint(0, len(wav) - sample_length - 1)
62
+ return wav[random_offset : random_offset + sample_length]
63
+
64
+
65
+ @dataclass
66
+ class DataTrainingArguments:
67
+ """
68
+ Arguments pertaining to what data we are going to input our model for training and eval.
69
+ Using `HfArgumentParser` we can turn this class
70
+ into argparse arguments to be able to specify them on
71
+ the command line.
72
+ """
73
+
74
+ dataset_name: Optional[str] = field(default=None, metadata={"help": "Name of a dataset from the datasets package"})
75
+ dataset_config_name: Optional[str] = field(
76
+ default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
77
+ )
78
+ train_file: Optional[str] = field(
79
+ default=None, metadata={"help": "A file containing the training audio paths and labels."}
80
+ )
81
+ eval_file: Optional[str] = field(
82
+ default=None, metadata={"help": "A file containing the validation audio paths and labels."}
83
+ )
84
+ train_split_name: str = field(
85
+ default="train",
86
+ metadata={
87
+ "help": "The name of the training data set split to use (via the datasets library). Defaults to 'train'"
88
+ },
89
+ )
90
+ eval_split_name: str = field(
91
+ default="validation",
92
+ metadata={
93
+ "help": (
94
+ "The name of the training data set split to use (via the datasets library). Defaults to 'validation'"
95
+ )
96
+ },
97
+ )
98
+ audio_column_name: str = field(
99
+ default="audio",
100
+ metadata={"help": "The name of the dataset column containing the audio data. Defaults to 'audio'"},
101
+ )
102
+ label_column_name: str = field(
103
+ default="label", metadata={"help": "The name of the dataset column containing the labels. Defaults to 'label'"}
104
+ )
105
+ max_train_samples: Optional[int] = field(
106
+ default=None,
107
+ metadata={
108
+ "help": (
109
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
110
+ "value if set."
111
+ )
112
+ },
113
+ )
114
+ max_eval_samples: Optional[int] = field(
115
+ default=None,
116
+ metadata={
117
+ "help": (
118
+ "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
119
+ "value if set."
120
+ )
121
+ },
122
+ )
123
+ max_length_seconds: float = field(
124
+ default=20,
125
+ metadata={"help": "Audio clips will be randomly cut to this length during training if the value is set."},
126
+ )
127
+
128
+
129
+ @dataclass
130
+ class ModelArguments:
131
+ """
132
+ Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
133
+ """
134
+
135
+ model_name_or_path: str = field(
136
+ default="facebook/wav2vec2-base",
137
+ metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"},
138
+ )
139
+ config_name: Optional[str] = field(
140
+ default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
141
+ )
142
+ cache_dir: Optional[str] = field(
143
+ default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from the Hub"}
144
+ )
145
+ model_revision: str = field(
146
+ default="main",
147
+ metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
148
+ )
149
+ feature_extractor_name: Optional[str] = field(
150
+ default=None, metadata={"help": "Name or path of preprocessor config."}
151
+ )
152
+ freeze_feature_encoder: bool = field(
153
+ default=True, metadata={"help": "Whether to freeze the feature encoder layers of the model."}
154
+ )
155
+ attention_mask: bool = field(
156
+ default=True, metadata={"help": "Whether to generate an attention mask in the feature extractor."}
157
+ )
158
+ token: str = field(
159
+ default=None,
160
+ metadata={
161
+ "help": (
162
+ "The token to use as HTTP bearer authorization for remote files. If not specified, will use the token "
163
+ "generated when running `huggingface-cli login` (stored in `~/.huggingface`)."
164
+ )
165
+ },
166
+ )
167
+ trust_remote_code: bool = field(
168
+ default=False,
169
+ metadata={
170
+ "help": (
171
+ "Whether to trust the execution of code from datasets/models defined on the Hub."
172
+ " This option should only be set to `True` for repositories you trust and in which you have read the"
173
+ " code, as it will execute code present on the Hub on your local machine."
174
+ )
175
+ },
176
+ )
177
+ freeze_feature_extractor: Optional[bool] = field(
178
+ default=None, metadata={"help": "Whether to freeze the feature extractor layers of the model."}
179
+ )
180
+ ignore_mismatched_sizes: bool = field(
181
+ default=False,
182
+ metadata={"help": "Will enable to load a pretrained model whose head dimensions are different."},
183
+ )
184
+
185
+ apply_trp: Optional[bool] = field(
186
+ default=False,
187
+ metadata={"help": "Whether to apply trp or not."},
188
+ )
189
+ trp_depths: Optional[int] = field(
190
+ default=1,
191
+ metadata={
192
+ "help": "TRP depth value."
193
+ },
194
+ )
195
+ trp_p: Optional[float] = field(
196
+ default=0.1,
197
+ metadata={
198
+ "help": "TRP p value."
199
+ },
200
+ )
201
+ trp_lambdas: Optional[List[float]] = field(
202
+ default_factory=lambda: [0.4, 0.2, 0.1],
203
+ metadata={
204
+ "help": "TRP lambda values (list of floats)."
205
+ },
206
+ )
207
+
208
+ def __post_init__(self):
209
+ if not self.freeze_feature_extractor and self.freeze_feature_encoder:
210
+ warnings.warn(
211
+ "The argument `--freeze_feature_extractor` is deprecated and "
212
+ "will be removed in a future version. Use `--freeze_feature_encoder` "
213
+ "instead. Setting `freeze_feature_encoder==True`.",
214
+ FutureWarning,
215
+ )
216
+ if self.freeze_feature_extractor and not self.freeze_feature_encoder:
217
+ raise ValueError(
218
+ "The argument `--freeze_feature_extractor` is deprecated and "
219
+ "should not be used in combination with `--freeze_feature_encoder`. "
220
+ "Only make use of `--freeze_feature_encoder`."
221
+ )
222
+
223
+
224
+ def main():
225
+ # See all possible arguments in src/transformers/training_args.py
226
+ # or by passing the --help flag to this script.
227
+ # We now keep distinct sets of args, for a cleaner separation of concerns.
228
+
229
+ parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
230
+ if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
231
+ # If we pass only one argument to the script and it's the path to a json file,
232
+ # let's parse it to get our arguments.
233
+ model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
234
+ else:
235
+ model_args, data_args, training_args = parser.parse_args_into_dataclasses()
236
+
237
+ # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
238
+ # information sent is the one passed as arguments along with your Python/PyTorch versions.
239
+ send_example_telemetry("run_audio_classification", model_args, data_args)
240
+
241
+ # Setup logging
242
+ logging.basicConfig(
243
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
244
+ datefmt="%m/%d/%Y %H:%M:%S",
245
+ handlers=[logging.StreamHandler(sys.stdout)],
246
+ )
247
+
248
+ if training_args.should_log:
249
+ # The default of training_args.log_level is passive, so we set log level at info here to have that default.
250
+ transformers.utils.logging.set_verbosity_info()
251
+
252
+ log_level = training_args.get_process_log_level()
253
+ logger.setLevel(log_level)
254
+ transformers.utils.logging.set_verbosity(log_level)
255
+ transformers.utils.logging.enable_default_handler()
256
+ transformers.utils.logging.enable_explicit_format()
257
+
258
+ # Log on each process the small summary:
259
+ logger.warning(
260
+ f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
261
+ + f"distributed training: {training_args.parallel_mode.value == 'distributed'}, 16-bits training: {training_args.fp16}"
262
+ )
263
+ logger.info(f"Training/evaluation parameters {training_args}")
264
+
265
+ # Set seed before initializing model.
266
+ set_seed(training_args.seed)
267
+
268
+ # Detecting last checkpoint.
269
+ last_checkpoint = None
270
+ if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
271
+ last_checkpoint = get_last_checkpoint(training_args.output_dir)
272
+ if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
273
+ raise ValueError(
274
+ f"Output directory ({training_args.output_dir}) already exists and is not empty. "
275
+ "Use --overwrite_output_dir to train from scratch."
276
+ )
277
+ elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
278
+ logger.info(
279
+ f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
280
+ "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
281
+ )
282
+
283
+ # Initialize our dataset and prepare it for the audio classification task.
284
+ raw_datasets = DatasetDict()
285
+ raw_datasets["train"] = load_dataset(
286
+ data_args.dataset_name,
287
+ data_args.dataset_config_name,
288
+ split=data_args.train_split_name,
289
+ token=model_args.token,
290
+ trust_remote_code=model_args.trust_remote_code,
291
+ )
292
+ raw_datasets["eval"] = load_dataset(
293
+ data_args.dataset_name,
294
+ data_args.dataset_config_name,
295
+ split=data_args.eval_split_name,
296
+ token=model_args.token,
297
+ trust_remote_code=model_args.trust_remote_code,
298
+ )
299
+
300
+ if data_args.audio_column_name not in raw_datasets["train"].column_names:
301
+ raise ValueError(
302
+ f"--audio_column_name {data_args.audio_column_name} not found in dataset '{data_args.dataset_name}'. "
303
+ "Make sure to set `--audio_column_name` to the correct audio column - one of "
304
+ f"{', '.join(raw_datasets['train'].column_names)}."
305
+ )
306
+
307
+ if data_args.label_column_name not in raw_datasets["train"].column_names:
308
+ raise ValueError(
309
+ f"--label_column_name {data_args.label_column_name} not found in dataset '{data_args.dataset_name}'. "
310
+ "Make sure to set `--label_column_name` to the correct text column - one of "
311
+ f"{', '.join(raw_datasets['train'].column_names)}."
312
+ )
313
+
314
+ # Setting `return_attention_mask=True` is the way to get a correctly masked mean-pooling over
315
+ # transformer outputs in the classifier, but it doesn't always lead to better accuracy
316
+ feature_extractor = AutoFeatureExtractor.from_pretrained(
317
+ model_args.feature_extractor_name or model_args.model_name_or_path,
318
+ return_attention_mask=model_args.attention_mask,
319
+ cache_dir=model_args.cache_dir,
320
+ revision=model_args.model_revision,
321
+ token=model_args.token,
322
+ trust_remote_code=model_args.trust_remote_code,
323
+ )
324
+
325
+ # `datasets` takes care of automatically loading and resampling the audio,
326
+ # so we just need to set the correct target sampling rate.
327
+ raw_datasets = raw_datasets.cast_column(
328
+ data_args.audio_column_name, datasets.features.Audio(sampling_rate=feature_extractor.sampling_rate)
329
+ )
330
+
331
+ model_input_name = feature_extractor.model_input_names[0]
332
+
333
+ def train_transforms(batch):
334
+ """Apply train_transforms across a batch."""
335
+ subsampled_wavs = []
336
+ for audio in batch[data_args.audio_column_name]:
337
+ wav = random_subsample(
338
+ audio["array"], max_length=data_args.max_length_seconds, sample_rate=feature_extractor.sampling_rate
339
+ )
340
+ subsampled_wavs.append(wav)
341
+ inputs = feature_extractor(subsampled_wavs, sampling_rate=feature_extractor.sampling_rate)
342
+ output_batch = {model_input_name: inputs.get(model_input_name)}
343
+ output_batch["labels"] = list(batch[data_args.label_column_name])
344
+
345
+ return output_batch
346
+
347
+ def val_transforms(batch):
348
+ """Apply val_transforms across a batch."""
349
+ wavs = [audio["array"] for audio in batch[data_args.audio_column_name]]
350
+ inputs = feature_extractor(wavs, sampling_rate=feature_extractor.sampling_rate)
351
+ output_batch = {model_input_name: inputs.get(model_input_name)}
352
+ output_batch["labels"] = list(batch[data_args.label_column_name])
353
+
354
+ return output_batch
355
+
356
+ # Prepare label mappings.
357
+ # We'll include these in the model's config to get human readable labels in the Inference API.
358
+ labels = raw_datasets["train"].features[data_args.label_column_name].names
359
+ label2id, id2label = {}, {}
360
+ for i, label in enumerate(labels):
361
+ label2id[label] = str(i)
362
+ id2label[str(i)] = label
363
+
364
+ # Load the accuracy metric from the datasets package
365
+ metric = evaluate.load("accuracy", cache_dir=model_args.cache_dir)
366
+
367
+ # Define our compute_metrics function. It takes an `EvalPrediction` object (a namedtuple with
368
+ # `predictions` and `label_ids` fields) and has to return a dictionary string to float.
369
+ def compute_metrics(eval_pred):
370
+ """Computes accuracy on a batch of predictions"""
371
+ predictions = np.argmax(eval_pred.predictions, axis=1)
372
+ return metric.compute(predictions=predictions, references=eval_pred.label_ids)
373
+
374
+ config = AutoConfig.from_pretrained(
375
+ model_args.config_name or model_args.model_name_or_path,
376
+ num_labels=len(labels),
377
+ label2id=label2id,
378
+ id2label=id2label,
379
+ finetuning_task="audio-classification",
380
+ cache_dir=model_args.cache_dir,
381
+ revision=model_args.model_revision,
382
+ token=model_args.token,
383
+ trust_remote_code=model_args.trust_remote_code,
384
+ )
385
+ model = AutoModelForAudioClassification.from_pretrained(
386
+ model_args.model_name_or_path,
387
+ from_tf=bool(".ckpt" in model_args.model_name_or_path),
388
+ config=config,
389
+ cache_dir=model_args.cache_dir,
390
+ revision=model_args.model_revision,
391
+ token=model_args.token,
392
+ trust_remote_code=model_args.trust_remote_code,
393
+ ignore_mismatched_sizes=model_args.ignore_mismatched_sizes,
394
+ )
395
+
396
+ # freeze the convolutional waveform encoder
397
+ if model_args.freeze_feature_encoder:
398
+ model.freeze_feature_encoder()
399
+
400
+ if model_args.apply_trp:
401
+ model = apply_trp(model, depths=model_args.trp_depths, p=model_args.trp_p, lambdas=model_args.trp_lambdas)
402
+
403
+ if training_args.do_train:
404
+ if data_args.max_train_samples is not None:
405
+ raw_datasets["train"] = (
406
+ raw_datasets["train"].shuffle(seed=training_args.seed).select(range(data_args.max_train_samples))
407
+ )
408
+ # Set the training transforms
409
+ raw_datasets["train"].set_transform(train_transforms, output_all_columns=False)
410
+
411
+ if training_args.do_eval:
412
+ if data_args.max_eval_samples is not None:
413
+ raw_datasets["eval"] = (
414
+ raw_datasets["eval"].shuffle(seed=training_args.seed).select(range(data_args.max_eval_samples))
415
+ )
416
+ # Set the validation transforms
417
+ raw_datasets["eval"].set_transform(val_transforms, output_all_columns=False)
418
+
419
+ # Initialize our trainer
420
+ trainer = Trainer(
421
+ model=model,
422
+ args=training_args,
423
+ train_dataset=raw_datasets["train"] if training_args.do_train else None,
424
+ eval_dataset=raw_datasets["eval"] if training_args.do_eval else None,
425
+ compute_metrics=compute_metrics,
426
+ processing_class=feature_extractor,
427
+ )
428
+
429
+ # Training
430
+ if training_args.do_train:
431
+ checkpoint = None
432
+ if training_args.resume_from_checkpoint is not None:
433
+ checkpoint = training_args.resume_from_checkpoint
434
+ elif last_checkpoint is not None:
435
+ checkpoint = last_checkpoint
436
+ train_result = trainer.train(resume_from_checkpoint=checkpoint)
437
+ trainer.save_model()
438
+ trainer.log_metrics("train", train_result.metrics)
439
+ trainer.save_metrics("train", train_result.metrics)
440
+ trainer.save_state()
441
+
442
+ # Evaluation
443
+ if training_args.do_eval:
444
+ metrics = trainer.evaluate()
445
+ trainer.log_metrics("eval", metrics)
446
+ trainer.save_metrics("eval", metrics)
447
+
448
+ # Write model card and (optionally) push to hub
449
+ kwargs = {
450
+ "finetuned_from": model_args.model_name_or_path,
451
+ "tasks": "audio-classification",
452
+ "dataset": data_args.dataset_name,
453
+ "tags": ["audio-classification"],
454
+ }
455
+ if training_args.push_to_hub:
456
+ trainer.push_to_hub(**kwargs)
457
+ else:
458
+ trainer.create_model_card(**kwargs)
459
+
460
+
461
+ if __name__ == "__main__":
462
+ main()
hpo-examples/audio-classification/trplib.py ADDED
@@ -0,0 +1,1181 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn, Tensor
3
+ from torch.nn import functional as F
4
+
5
+ from torchvision.models.mobilenetv2 import MobileNetV2
6
+ from torchvision.models.resnet import ResNet
7
+ from torchvision.models.efficientnet import EfficientNet
8
+ from torchvision.models.vision_transformer import VisionTransformer
9
+ from torchvision.models.segmentation.fcn import FCN
10
+ from torchvision.models.segmentation.deeplabv3 import DeepLabV3
11
+
12
+ import transformers
13
+ from transformers.modeling_outputs import SequenceClassifierOutput, QuestionAnsweringModelOutput, CausalLMOutput, Seq2SeqLMOutput
14
+
15
+ from typing import Optional, Tuple, List, Union, Callable
16
+ from collections import OrderedDict
17
+ import types
18
+
19
+
20
+ def trp_criterion(trp_blocks: nn.ModuleList, shared_head: Callable, criterion: Callable, lambdas: List[float], hidden_states: Tensor, logits: Tensor, targets: Tensor, loss_normalization=False):
21
+ loss, mask = criterion(logits, targets)
22
+ if loss_normalization:
23
+ coeff = loss.detach()
24
+
25
+ embeds = [hidden_states]
26
+ predictions = []
27
+ for k, c in enumerate(lambdas):
28
+ embeds.append(trp_blocks[k](embeds[-1]))
29
+ predictions.append(shared_head(embeds[-1]))
30
+ replica_loss, mask = criterion(predictions[-1], targets, mask)
31
+ loss += c * replica_loss
32
+
33
+ if loss_normalization:
34
+ with torch.no_grad():
35
+ coeff = torch.exp(coeff) / torch.exp(loss.detach())
36
+ loss = coeff * loss
37
+
38
+ return loss
39
+
40
+
41
+ class TPBlock(nn.Module):
42
+ def __init__(self, depths: int, in_features: int, p: float, dim=-1):
43
+ super(TPBlock, self).__init__()
44
+
45
+ self.dropout = nn.Dropout(p)
46
+
47
+ self.cdim = dim
48
+
49
+ blocks = []
50
+ for _ in range(depths):
51
+ blocks.append(nn.Linear(in_features, in_features))
52
+ nn.init.constant_(blocks[-1].weight, 0.0)
53
+ nn.init.constant_(blocks[-1].bias, 0.0)
54
+ blocks.append(nn.ReLU())
55
+ self.blocks = nn.Sequential(*blocks)
56
+
57
+ def forward(self, x):
58
+ x = self.dropout(x)
59
+ if self.cdim == -1:
60
+ x = x + self.blocks(x)
61
+ else:
62
+ x = x + torch.movedim(self.blocks(torch.movedim(x, self.cdim, -1)), -1, self.cdim)
63
+ return x
64
+
65
+
66
+ class Config:
67
+ @staticmethod
68
+ def gen_criterion(*args, **kwargs):
69
+ def func(input, target, mask=None):
70
+ """
71
+ Args:
72
+ input (Tensor): Input tensor.
73
+ target (Tensor): Target labels.
74
+
75
+ Returns:
76
+ loss (Tensor): Scalar tensor representing the loss.
77
+ mask (Tensor): Boolean mask tensor with the same shape of target.
78
+ """
79
+ pass
80
+ return func
81
+
82
+ @staticmethod
83
+ def gen_shared_head(*args, **kwargs):
84
+ def func(hidden_states):
85
+ """
86
+ Args:
87
+ hidden_states (Tensor): Hidden States tensor.
88
+
89
+ Returns:
90
+ logits (Tensor): Logits tensor.
91
+ """
92
+ pass
93
+ return func
94
+
95
+ @staticmethod
96
+ def forward(*args, **kwargs):
97
+ pass
98
+
99
+
100
+ # Wav2Vec2 for Audio Classification
101
+ class Wav2Vec2ForSequenceClassificationConfig(Config):
102
+ _HIDDEN_STATES_START_POSITION = 2
103
+
104
+ @staticmethod
105
+ def gen_criterion():
106
+ def func(input, target, mask=None):
107
+ """
108
+ Args:
109
+ input (Tensor): Input tensor of shape [B, C].
110
+ target (Tensor): Target labels of shape [B].
111
+
112
+ Returns:
113
+ loss (Tensor): Scalar tensor representing the loss.
114
+ mask (Tensor): Boolean mask tensor of shape [B].
115
+ """
116
+ if mask is None:
117
+ mask = torch.ones_like(target, dtype=torch.float32, device=target.device)
118
+
119
+ unmasked_loss = F.cross_entropy(input, target, reduction="none")
120
+ loss = torch.sum(mask * unmasked_loss) / (torch.sum(mask) + 1e-6)
121
+
122
+ with torch.no_grad():
123
+ mask = mask * torch.eq(torch.argmax(input, dim=1), target).to(input.dtype)
124
+
125
+ return loss, mask
126
+ return func
127
+
128
+ @staticmethod
129
+ def gen_shared_head(self, attention_mask):
130
+ def func(hidden_states):
131
+ """
132
+ Args:
133
+ hidden_states (Tensor): Hidden States of shape [B, L, hidden_units].
134
+
135
+ Returns:
136
+ logits (Tensor): Logits tensor of shape [B, C].
137
+ """
138
+ _hidden_states = self.projector(hidden_states)
139
+ if attention_mask is None:
140
+ pooled_output = _hidden_states.mean(dim=1)
141
+ else:
142
+ padding_mask = self._get_feature_vector_attention_mask(_hidden_states.shape[1], attention_mask)
143
+ expand_padding_mask = padding_mask.unsqueeze(-1).repeat(1, 1, _hidden_states.shape[2])
144
+ _hidden_states[~expand_padding_mask] = 0.0
145
+ pooled_output = _hidden_states.sum(dim=1) / padding_mask.sum(dim=1).view(-1, 1)
146
+
147
+ logits = self.classifier(pooled_output)
148
+ return logits
149
+ return func
150
+
151
+ @staticmethod
152
+ def gen_forward(lambdas, loss_normalization=False):
153
+ def func(
154
+ self,
155
+ input_values: Optional[torch.Tensor],
156
+ attention_mask: Optional[torch.Tensor] = None,
157
+ output_attentions: Optional[bool] = None,
158
+ output_hidden_states: Optional[bool] = None,
159
+ return_dict: Optional[bool] = None,
160
+ labels: Optional[torch.Tensor] = None,
161
+ ) -> Union[Tuple, SequenceClassifierOutput]:
162
+ r"""
163
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
164
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
165
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
166
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
167
+ """
168
+
169
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
170
+ output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states
171
+
172
+ outputs = self.wav2vec2(
173
+ input_values,
174
+ attention_mask=attention_mask,
175
+ output_attentions=output_attentions,
176
+ output_hidden_states=output_hidden_states,
177
+ return_dict=return_dict,
178
+ )
179
+
180
+ if self.config.use_weighted_layer_sum:
181
+ hidden_states = outputs[Wav2Vec2ForSequenceClassificationConfig._HIDDEN_STATES_START_POSITION]
182
+ hidden_states = torch.stack(hidden_states, dim=1)
183
+ norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
184
+ hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
185
+ else:
186
+ hidden_states = outputs[0]
187
+
188
+ _hidden_states = self.projector(hidden_states)
189
+ if attention_mask is None:
190
+ pooled_output = _hidden_states.mean(dim=1)
191
+ else:
192
+ padding_mask = self._get_feature_vector_attention_mask(_hidden_states.shape[1], attention_mask)
193
+ expand_padding_mask = padding_mask.unsqueeze(-1).repeat(1, 1, _hidden_states.shape[2])
194
+ _hidden_states[~expand_padding_mask] = 0.0
195
+ pooled_output = _hidden_states.sum(dim=1) / padding_mask.sum(dim=1).view(-1, 1)
196
+
197
+ logits = self.classifier(pooled_output)
198
+
199
+ loss = None
200
+ if labels is not None:
201
+ shared_head = Wav2Vec2ForSequenceClassificationConfig.gen_shared_head(self, attention_mask)
202
+ criterion = Wav2Vec2ForSequenceClassificationConfig.gen_criterion()
203
+ loss = trp_criterion(self.trp_blocks, shared_head, criterion, lambdas, hidden_states, logits.view(-1, self.config.num_labels), labels.view(-1), loss_normalization) # NOTE: Apply TRP!
204
+
205
+ if not return_dict:
206
+ output = (logits,) + outputs[Wav2Vec2ForSequenceClassificationConfig._HIDDEN_STATES_START_POSITION:]
207
+ return ((loss,) + output) if loss is not None else output
208
+
209
+ return SequenceClassifierOutput(
210
+ loss=loss,
211
+ logits=logits,
212
+ hidden_states=outputs.hidden_states,
213
+ attentions=outputs.attentions,
214
+ )
215
+ return func
216
+
217
+
218
+ # MobileNetV2 for Image Classification
219
+ class MobileNetV2Config(Config):
220
+ @staticmethod
221
+ def gen_criterion(label_smoothing=0.0, top_k=1):
222
+ def func(input, target, mask=None):
223
+ """
224
+ Args:
225
+ input (Tensor): Input tensor of shape [B, C].
226
+ target (Tensor): Target labels of shape [B] or [B, C].
227
+
228
+ Returns:
229
+ loss (Tensor): Scalar tensor representing the loss.
230
+ mask (Tensor): Boolean mask tensor of shape [B].
231
+ """
232
+ label = torch.argmax(target, dim=1) if label_smoothing > 0.0 else target
233
+
234
+ unmasked_loss = F.cross_entropy(input, label, reduction="none", label_smoothing=label_smoothing)
235
+ if mask is None:
236
+ mask = torch.ones_like(unmasked_loss, dtype=torch.float32, device=target.device)
237
+ loss = torch.sum(mask * unmasked_loss) / (torch.sum(mask) + 1e-6)
238
+
239
+ with torch.no_grad():
240
+ topk_values, topk_indices = torch.topk(input, top_k, dim=-1)
241
+ mask = mask * torch.eq(topk_indices, label[:, None]).any(dim=-1).to(input.dtype)
242
+
243
+ return loss, mask
244
+ return func
245
+
246
+ @staticmethod
247
+ def gen_shared_head(self):
248
+ def func(x):
249
+ """
250
+ Args:
251
+ x (Tensor): Hidden States tensor of shape [B, hidden_units].
252
+
253
+ Returns:
254
+ logits (Tensor): Logits tensor of shape [B, C].
255
+ """
256
+ logits = self.classifier(x)
257
+ return logits
258
+ return func
259
+
260
+ @staticmethod
261
+ def gen_forward(lambdas, loss_normalization=True, label_smoothing=0.0, top_k=1):
262
+ def func(self, images: Tensor, targets=None):
263
+ x = self.features(images)
264
+ x = nn.functional.adaptive_avg_pool2d(x, (1, 1))
265
+ x = torch.flatten(x, 1)
266
+ logits = self.classifier(x)
267
+
268
+ if self.training:
269
+ torch._assert(targets is not None, "targets should not be none when in training mode")
270
+ shared_head = MobileNetV2Config.gen_shared_head(self)
271
+ criterion = MobileNetV2Config.gen_criterion(label_smoothing, top_k)
272
+ loss = trp_criterion(self.trp_blocks, shared_head, criterion, lambdas, x, logits, targets, loss_normalization)
273
+ return logits, loss
274
+ return logits
275
+ return func
276
+
277
+
278
+ # ResNet for Image Classification
279
+ class ResNetConfig(MobileNetV2Config):
280
+ @staticmethod
281
+ def gen_shared_head(self):
282
+ def func(x):
283
+ """
284
+ Args:
285
+ x (Tensor): Hidden States tensor of shape [B, hidden_units].
286
+
287
+ Returns:
288
+ logits (Tensor): Logits tensor of shape [B, C].
289
+ """
290
+ logits = self.fc(x)
291
+ return logits
292
+ return func
293
+
294
+ @staticmethod
295
+ def gen_forward(lambdas, loss_normalization=True, label_smoothing=0.0, top_k=1):
296
+ def func(self, images: Tensor, targets=None):
297
+ x = self.conv1(images)
298
+ x = self.bn1(x)
299
+ x = self.relu(x)
300
+ x = self.maxpool(x)
301
+
302
+ x = self.layer1(x)
303
+ x = self.layer2(x)
304
+ x = self.layer3(x)
305
+ x = self.layer4(x)
306
+
307
+ x = self.avgpool(x)
308
+ x = torch.flatten(x, 1)
309
+ logits = self.fc(x)
310
+
311
+ if self.training:
312
+ torch._assert(targets is not None, "targets should not be none when in training mode")
313
+ shared_head = ResNetConfig.gen_shared_head(self)
314
+ criterion = ResNetConfig.gen_criterion(label_smoothing, top_k)
315
+ loss = trp_criterion(self.trp_blocks, shared_head, criterion, lambdas, x, logits, targets, loss_normalization)
316
+ return logits, loss
317
+ return logits
318
+ return func
319
+
320
+
321
+ # EfficientNet for Image Classification
322
+ class EfficientNetConfig(MobileNetV2Config):
323
+ @staticmethod
324
+ def gen_shared_head(self):
325
+ def func(x):
326
+ """
327
+ Args:
328
+ x (Tensor): Hidden States tensor of shape [B, hidden_units].
329
+
330
+ Returns:
331
+ logits (Tensor): Logits tensor of shape [B, C].
332
+ """
333
+ logits = self.classifier(x)
334
+ return logits
335
+ return func
336
+
337
+ @staticmethod
338
+ def gen_forward(lambdas, loss_normalization=True, label_smoothing=0.0, top_k=1):
339
+ def func(self, images: Tensor, targets=None):
340
+ x = self.features(images)
341
+ x = self.avgpool(x)
342
+ x = torch.flatten(x, 1)
343
+ logits = self.classifier(x)
344
+
345
+ if self.training:
346
+ torch._assert(targets is not None, "targets should not be none when in training mode")
347
+ shared_head = EfficientNetConfig.gen_shared_head(self)
348
+ criterion = EfficientNetConfig.gen_criterion(label_smoothing, top_k)
349
+ loss = trp_criterion(self.trp_blocks, shared_head, criterion, lambdas, x, logits, targets, loss_normalization)
350
+ return logits, loss
351
+ return logits
352
+ return func
353
+
354
+
355
+ # ViT for Image Classification
356
+ class VisionTransformerConfig(MobileNetV2Config):
357
+ @staticmethod
358
+ def gen_shared_head(self):
359
+ def func(x):
360
+ """
361
+ Args:
362
+ x (Tensor): Hidden States tensor of shape [B, hidden_units].
363
+
364
+ Returns:
365
+ logits (Tensor): Logits tensor of shape [B, C].
366
+ """
367
+ logits = self.heads(x)
368
+ return logits
369
+ return func
370
+
371
+ @staticmethod
372
+ def gen_forward(lambdas, loss_normalization=True, label_smoothing=0.0, top_k=1):
373
+ def func(self, images: Tensor, targets=None):
374
+ x = self._process_input(images)
375
+ n = x.shape[0]
376
+ batch_class_token = self.class_token.expand(n, -1, -1)
377
+ x = torch.cat([batch_class_token, x], dim=1)
378
+ x = self.encoder(x)
379
+ x = x[:, 0]
380
+
381
+ logits = self.heads(x)
382
+
383
+ if self.training:
384
+ torch._assert(targets is not None, "targets should not be none when in training mode")
385
+ shared_head = VisionTransformerConfig.gen_shared_head(self)
386
+ criterion = VisionTransformerConfig.gen_criterion(label_smoothing, top_k)
387
+ loss = trp_criterion(self.trp_blocks, shared_head, criterion, lambdas, x, logits, targets, loss_normalization)
388
+ return logits, loss
389
+ return logits
390
+ return func
391
+
392
+
393
+ # Bert for Question Answering
394
+ class BertForQuestionAnsweringConfig(Config):
395
+ @staticmethod
396
+ def gen_criterion(top_k=1):
397
+ def func(input, target: List[Tensor], mask=None):
398
+ """
399
+ Args:
400
+ input (Tensor): Input tensor of shape [B, C, 2].
401
+ target (List[Tensor]):
402
+ Start Positions of shape [B].
403
+ End Positions of shape [B].
404
+
405
+ Returns:
406
+ loss (Tensor): Scalar tensor representing the loss.
407
+ mask (Tensor): Boolean mask tensor of shape [B].
408
+ """
409
+ start_positions, end_positions = target
410
+
411
+ if mask is None:
412
+ mask = torch.ones_like(start_positions, dtype=torch.float32, device=start_positions.device)
413
+
414
+ start_logits, end_logits = input.split(1, dim=-1)
415
+ start_logits = start_logits.squeeze(-1).contiguous()
416
+ end_logits = end_logits.squeeze(-1).contiguous()
417
+
418
+ # If we are on multi-GPU, split add a dimension
419
+ if len(start_positions.size()) > 1:
420
+ start_positions = start_positions.squeeze(-1)
421
+ if len(end_positions.size()) > 1:
422
+ end_positions = end_positions.squeeze(-1)
423
+ # sometimes the start/end positions are outside our model inputs, we ignore these terms
424
+ ignored_index = start_logits.size(1)
425
+ start_positions = start_positions.clamp(0, ignored_index)
426
+ end_positions = end_positions.clamp(0, ignored_index)
427
+
428
+ masked_start_losses = F.cross_entropy(start_logits, start_positions, ignore_index=ignored_index, reduction="none")
429
+ start_loss = torch.sum(mask * masked_start_losses) / (torch.sum(mask) + 1e-6)
430
+ masked_end_losses = F.cross_entropy(end_logits, end_positions, ignore_index=ignored_index, reduction="none")
431
+ end_loss = torch.sum(mask * masked_end_losses) / (torch.sum(mask) + 1e-6)
432
+
433
+ with torch.no_grad():
434
+ topk_values, topk_indices = torch.topk(start_logits, top_k, dim=1)
435
+ mask = mask * torch.eq(topk_indices, start_positions[:, None]).any(dim=1).to(start_logits.dtype)
436
+ topk_values, topk_indices = torch.topk(end_logits, top_k, dim=1)
437
+ mask = mask * torch.eq(topk_indices, end_positions[:, None]).any(dim=1).to(end_logits.dtype)
438
+
439
+ return (start_loss + end_loss) / 2, mask
440
+ return func
441
+
442
+ @staticmethod
443
+ def gen_shared_head(self):
444
+ def func(hidden_states):
445
+ """
446
+ Args:
447
+ hidden_states (Tensor): Hidden States of shape [B, C, hidden_units].
448
+
449
+ Returns:
450
+ logits (Tensor): Logits tensor of shape [B, C, 2].
451
+ """
452
+ logits = self.qa_outputs(hidden_states)
453
+ return logits
454
+ return func
455
+
456
+ @staticmethod
457
+ def gen_forward(lambdas, loss_normalization=True, top_k=1):
458
+ def func(
459
+ self,
460
+ input_ids: Optional[torch.Tensor] = None,
461
+ attention_mask: Optional[torch.Tensor] = None,
462
+ token_type_ids: Optional[torch.Tensor] = None,
463
+ position_ids: Optional[torch.Tensor] = None,
464
+ head_mask: Optional[torch.Tensor] = None,
465
+ inputs_embeds: Optional[torch.Tensor] = None,
466
+ start_positions: Optional[torch.Tensor] = None,
467
+ end_positions: Optional[torch.Tensor] = None,
468
+ output_attentions: Optional[bool] = None,
469
+ output_hidden_states: Optional[bool] = None,
470
+ return_dict: Optional[bool] = None,
471
+ ) -> Union[Tuple[torch.Tensor], QuestionAnsweringModelOutput]:
472
+ r"""
473
+ start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
474
+ Labels for position (index) of the start of the labelled span for computing the token classification loss.
475
+ Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
476
+ are not taken into account for computing the loss.
477
+ end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
478
+ Labels for position (index) of the end of the labelled span for computing the token classification loss.
479
+ Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
480
+ are not taken into account for computing the loss.
481
+ """
482
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
483
+
484
+ outputs = self.bert(
485
+ input_ids,
486
+ attention_mask=attention_mask,
487
+ token_type_ids=token_type_ids,
488
+ position_ids=position_ids,
489
+ head_mask=head_mask,
490
+ inputs_embeds=inputs_embeds,
491
+ output_attentions=output_attentions,
492
+ output_hidden_states=output_hidden_states,
493
+ return_dict=return_dict,
494
+ )
495
+
496
+ sequence_output = outputs[0]
497
+
498
+ logits = self.qa_outputs(sequence_output)
499
+ start_logits, end_logits = logits.split(1, dim=-1)
500
+ start_logits = start_logits.squeeze(-1).contiguous()
501
+ end_logits = end_logits.squeeze(-1).contiguous()
502
+
503
+ total_loss = None
504
+ if start_positions is not None and end_positions is not None:
505
+ shared_head = BertForQuestionAnsweringConfig.gen_shared_head(self)
506
+ criterion = BertForQuestionAnsweringConfig.gen_criterion()
507
+ total_loss = trp_criterion(self.trp_blocks, shared_head, criterion, lambdas, sequence_output, logits, [start_positions, end_positions], loss_normalization) # NOTE: Apply TRP!
508
+
509
+ if not return_dict:
510
+ output = (start_logits, end_logits) + outputs[2:]
511
+ return ((total_loss,) + output) if total_loss is not None else output
512
+
513
+ return QuestionAnsweringModelOutput(
514
+ loss=total_loss,
515
+ start_logits=start_logits,
516
+ end_logits=end_logits,
517
+ hidden_states=outputs.hidden_states,
518
+ attentions=outputs.attentions,
519
+ )
520
+ return func
521
+
522
+
523
+ # FCN for Semantic Segmentation
524
+ class FCNConfig(Config):
525
+ @staticmethod
526
+ def gen_criterion(top_k=1):
527
+ def func(input, target, mask=None):
528
+ """
529
+ Args:
530
+ input Tensor: input tensor of shape [B, C, H, W].
531
+ target (Tensor): Target labels of shape [B, H, W].
532
+
533
+ Returns:
534
+ loss (Tensor): Scalar tensor representing the loss.
535
+ mask (Tensor): Boolean mask tensor of shape [B, H, W].
536
+ """
537
+ if mask is None:
538
+ mask = torch.ones_like(target, dtype=torch.float32, device=target.device)
539
+
540
+ masked_loss = F.cross_entropy(input, target, ignore_index=255, reduction="none")
541
+ loss = torch.sum(mask * masked_loss) / (torch.sum(mask) + 1e-6)
542
+
543
+ with torch.no_grad():
544
+ topk_values, topk_indices = torch.topk(input, top_k, dim=1)
545
+ mask = mask * torch.eq(topk_indices, target[:, None, :, :]).any(dim=1).to(input.dtype)
546
+ # mask = mask * torch.eq(torch.argmax(x, dim=1), target).to(x.dtype)
547
+
548
+ return loss, mask
549
+ return func
550
+
551
+ @staticmethod
552
+ def gen_out_shared_head(self, input_shape):
553
+ def func(features):
554
+ """
555
+ Args:
556
+ features (Tensor): features tensor of shape [B, hidden_units, H, W].
557
+
558
+ Returns:
559
+ result (Tensors): result tensor of shape [B, C, H, W].
560
+ """
561
+ x = self.classifier(features)
562
+ result = F.interpolate(x, size=input_shape, mode="bilinear", align_corners=False)
563
+ return result
564
+ return func
565
+
566
+ @staticmethod
567
+ def gen_aux_shared_head(self, input_shape):
568
+ def func(features):
569
+ """
570
+ Args:
571
+ features (Tensor): features tensor of shape [B, hidden_units, H, W].
572
+
573
+ Returns:
574
+ result (Tensors): result tensor of shape [B, C, H, W].
575
+ """
576
+ x = self.aux_classifier(features)
577
+ result = F.interpolate(x, size=input_shape, mode="bilinear", align_corners=False)
578
+ return result
579
+ return func
580
+
581
+ @staticmethod
582
+ def gen_forward(lambdas, loss_normalization=True, top_k=1):
583
+ def func(self, images: Tensor, targets=None):
584
+ input_shape = images.shape[-2:]
585
+ # contract: features is a dict of tensors
586
+ features = self.backbone(images)
587
+
588
+ result = OrderedDict()
589
+ x = features["out"]
590
+ x = self.classifier(x)
591
+ x = F.interpolate(x, size=input_shape, mode="bilinear", align_corners=False)
592
+ result["out"] = x
593
+
594
+ if self.aux_classifier is not None:
595
+ x = features["aux"]
596
+ x = self.aux_classifier(x)
597
+ x = F.interpolate(x, size=input_shape, mode="bilinear", align_corners=False)
598
+ result["aux"] = x
599
+
600
+ if self.training:
601
+ torch._assert(targets is not None, "targets should not be none when in training mode")
602
+ out_shared_head = FCNConfig.gen_out_shared_head(self, input_shape)
603
+ aux_shared_head = FCNConfig.gen_aux_shared_head(self, input_shape)
604
+ criterion = FCNConfig.gen_criterion(top_k)
605
+ out_loss = trp_criterion(self.out_trp_blocks, out_shared_head, criterion, lambdas, features["out"], result["out"], targets, loss_normalization)
606
+ aux_loss = trp_criterion(self.aux_trp_blocks, aux_shared_head, criterion, lambdas, features["aux"], result["aux"], targets, loss_normalization)
607
+ loss = out_loss + 0.5 * aux_loss
608
+ return result, loss
609
+ return result
610
+ return func
611
+
612
+
613
+ # DeepLabV3Config for Semantic Segmentation
614
+ class DeepLabV3Config(FCNConfig):
615
+ pass
616
+
617
+
618
+ # Bert for Text Classification
619
+ class BertForSequenceClassificationConfig(Config):
620
+ @staticmethod
621
+ def gen_criterion():
622
+ def func(input, target, mask=None):
623
+ """
624
+ Args:
625
+ input (Tensor): Input tensor of shape [B, C].
626
+ target (Tensor): Target labels of shape [B].
627
+
628
+ Returns:
629
+ loss (Tensor): Scalar tensor representing the loss.
630
+ mask (Tensor): Boolean mask tensor of shape [B].
631
+ """
632
+ if mask is None:
633
+ mask = torch.ones_like(target, dtype=torch.float32, device=target.device)
634
+
635
+ unmasked_loss = F.cross_entropy(input, target, reduction="none")
636
+ loss = torch.sum(mask * unmasked_loss) / (torch.sum(mask) + 1e-6)
637
+
638
+ with torch.no_grad():
639
+ mask = mask * torch.eq(torch.argmax(input, dim=1), target).to(input.dtype)
640
+
641
+ return loss, mask
642
+ return func
643
+
644
+ @staticmethod
645
+ def gen_shared_head(self):
646
+ def func(hidden_states):
647
+ """
648
+ Args:
649
+ hidden_states (Tensor): Hidden States of shape [B, hidden_units].
650
+
651
+ Returns:
652
+ logits (Tensor): Logits tensor of shape [B, C].
653
+ """
654
+ logits = self.classifier(hidden_states)
655
+ return logits
656
+ return func
657
+
658
+ @staticmethod
659
+ def gen_forward(lambdas, loss_normalization=False):
660
+ def func(
661
+ self,
662
+ input_ids: Optional[torch.Tensor] = None,
663
+ attention_mask: Optional[torch.Tensor] = None,
664
+ token_type_ids: Optional[torch.Tensor] = None,
665
+ position_ids: Optional[torch.Tensor] = None,
666
+ head_mask: Optional[torch.Tensor] = None,
667
+ inputs_embeds: Optional[torch.Tensor] = None,
668
+ labels: Optional[torch.Tensor] = None,
669
+ output_attentions: Optional[bool] = None,
670
+ output_hidden_states: Optional[bool] = None,
671
+ return_dict: Optional[bool] = None,
672
+ ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
673
+ r"""
674
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
675
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
676
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
677
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
678
+ """
679
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
680
+
681
+ outputs = self.bert(
682
+ input_ids,
683
+ attention_mask=attention_mask,
684
+ token_type_ids=token_type_ids,
685
+ position_ids=position_ids,
686
+ head_mask=head_mask,
687
+ inputs_embeds=inputs_embeds,
688
+ output_attentions=output_attentions,
689
+ output_hidden_states=output_hidden_states,
690
+ return_dict=return_dict,
691
+ )
692
+
693
+ pooled_output = outputs[1]
694
+
695
+ pooled_output = self.dropout(pooled_output)
696
+ logits = self.classifier(pooled_output)
697
+
698
+ loss = None
699
+ if labels is not None:
700
+ assert self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int) # TODO: remove this
701
+ if self.config.problem_type is None:
702
+ if self.num_labels == 1:
703
+ self.config.problem_type = "regression"
704
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
705
+ self.config.problem_type = "single_label_classification"
706
+ else:
707
+ self.config.problem_type = "multi_label_classification"
708
+
709
+ if self.config.problem_type == "regression":
710
+ if self.num_labels == 1:
711
+ loss = F.mse_loss(logits.squeeze(), labels.squeeze())
712
+ else:
713
+ loss = F.mse_loss(logits, labels)
714
+ elif self.config.problem_type == "single_label_classification":
715
+ shared_head = BertForSequenceClassificationConfig.gen_shared_head(self)
716
+ criterion = BertForSequenceClassificationConfig.gen_criterion()
717
+ loss = trp_criterion(self.trp_blocks, shared_head, criterion, lambdas, pooled_output, logits, labels, loss_normalization)
718
+ elif self.config.problem_type == "multi_label_classification":
719
+ loss = F.binary_cross_entropy_with_logits(logits, labels)
720
+ if not return_dict:
721
+ output = (logits,) + outputs[2:]
722
+ return ((loss,) + output) if loss is not None else output
723
+
724
+ return SequenceClassifierOutput(
725
+ loss=loss,
726
+ logits=logits,
727
+ hidden_states=outputs.hidden_states,
728
+ attentions=outputs.attentions,
729
+ )
730
+ return func
731
+
732
+
733
+ # Boberta for Text Classification
734
+ class RobertaForSequenceClassificationConfig(BertForSequenceClassificationConfig):
735
+ @staticmethod
736
+ def gen_shared_head(self):
737
+ def func(hidden_states):
738
+ """
739
+ Args:
740
+ hidden_states (Tensor): Hidden States of shape [B, hidden_units].
741
+
742
+ Returns:
743
+ logits (Tensor): Logits tensor of shape [B, C].
744
+ """
745
+ logits = self.classifier(hidden_states)
746
+ return logits
747
+ return func
748
+
749
+ @staticmethod
750
+ def gen_forward(lambdas, loss_normalization=False):
751
+ def func(
752
+ self,
753
+ input_ids: Optional[torch.LongTensor] = None,
754
+ attention_mask: Optional[torch.FloatTensor] = None,
755
+ token_type_ids: Optional[torch.LongTensor] = None,
756
+ position_ids: Optional[torch.LongTensor] = None,
757
+ head_mask: Optional[torch.FloatTensor] = None,
758
+ inputs_embeds: Optional[torch.FloatTensor] = None,
759
+ labels: Optional[torch.LongTensor] = None,
760
+ output_attentions: Optional[bool] = None,
761
+ output_hidden_states: Optional[bool] = None,
762
+ return_dict: Optional[bool] = None,
763
+ ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
764
+ r"""
765
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
766
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
767
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
768
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
769
+ """
770
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
771
+
772
+ outputs = self.roberta(
773
+ input_ids,
774
+ attention_mask=attention_mask,
775
+ token_type_ids=token_type_ids,
776
+ position_ids=position_ids,
777
+ head_mask=head_mask,
778
+ inputs_embeds=inputs_embeds,
779
+ output_attentions=output_attentions,
780
+ output_hidden_states=output_hidden_states,
781
+ return_dict=return_dict,
782
+ )
783
+ sequence_output = outputs[0]
784
+ logits = self.classifier(sequence_output)
785
+
786
+ loss = None
787
+ if labels is not None:
788
+ assert self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int) # TODO: remove this
789
+ # move labels to correct device to enable model parallelism
790
+ labels = labels.to(logits.device)
791
+ if self.config.problem_type is None:
792
+ if self.num_labels == 1:
793
+ self.config.problem_type = "regression"
794
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
795
+ self.config.problem_type = "single_label_classification"
796
+ else:
797
+ self.config.problem_type = "multi_label_classification"
798
+
799
+ if self.config.problem_type == "regression":
800
+ if self.num_labels == 1:
801
+ loss = F.mse_loss(logits.squeeze(), labels.squeeze())
802
+ else:
803
+ loss = F.mse_loss(logits, labels)
804
+ elif self.config.problem_type == "single_label_classification":
805
+ shared_head = BertForSequenceClassificationConfig.gen_shared_head(self)
806
+ criterion = BertForSequenceClassificationConfig.gen_criterion()
807
+ loss = trp_criterion(self.trp_blocks, shared_head, criterion, lambdas, sequence_output, logits, labels, loss_normalization)
808
+ elif self.config.problem_type == "multi_label_classification":
809
+ loss = F.binary_cross_entropy_with_logits(logits, labels)
810
+
811
+ if not return_dict:
812
+ output = (logits,) + outputs[2:]
813
+ return ((loss,) + output) if loss is not None else output
814
+
815
+ return SequenceClassifierOutput(
816
+ loss=loss,
817
+ logits=logits,
818
+ hidden_states=outputs.hidden_states,
819
+ attentions=outputs.attentions,
820
+ )
821
+ return func
822
+
823
+
824
+ # Wav2Vec2 for Speech Recognition
825
+ class Wav2Vec2ForCTCConfig(Config):
826
+ _HIDDEN_STATES_START_POSITION = 2
827
+
828
+ @staticmethod
829
+ def greedy_decode_ctc(
830
+ log_probs: torch.Tensor,
831
+ input_lengths: torch.Tensor,
832
+ blank_token_id: int,
833
+ target_lengths: torch.Tensor
834
+ ):
835
+ """
836
+ Convert logits to flattened predictions that match the shape of flattened_targets.
837
+
838
+ Args:
839
+ log_probs: [B, L, V] - log-softmax output
840
+ input_lengths: [B] - actual length of each input
841
+ blank_token_id: int - index of blank token
842
+ target_lengths: [B] - used to determine how many predictions to keep per sample
843
+
844
+ Returns:
845
+ flattened_predictions: 1D tensor, same total length as sum(target_lengths)
846
+ """
847
+ batch_size = log_probs.size(0)
848
+ decoded_all = []
849
+
850
+ predicted_ids = log_probs.argmax(dim=-1) # [B, L]
851
+
852
+ for i in range(batch_size):
853
+ pred = predicted_ids[i][:input_lengths[i]] # [Li]
854
+ prev = None
855
+ decoded = []
856
+ for token in pred:
857
+ token = token.item()
858
+ if token != blank_token_id and token != prev:
859
+ decoded.append(token)
860
+ prev = token
861
+ # Trim or pad to match target_lengths[i]
862
+ tgt_len = target_lengths[i].item()
863
+ if len(decoded) >= tgt_len:
864
+ decoded = decoded[:tgt_len]
865
+ else:
866
+ decoded = decoded + [blank_token_id] * (tgt_len - len(decoded)) # pad with blank
867
+ decoded_all.extend(decoded)
868
+
869
+ return torch.tensor(decoded_all, dtype=torch.long, device=log_probs.device) # shape: [sum(target_lengths)]
870
+
871
+ @staticmethod
872
+ def gen_criterion(input_lengths: Tensor, pad_token_id: int, ctc_zero_infinity: bool):
873
+ def func(logits: Tensor, labels: Tensor, mask=None):
874
+ """
875
+ Args:
876
+ logits (Tensor): Log Probablities of shape [B, L, V].
877
+ labels (Tensor): Flattened Targets of shape [B, L'].
878
+
879
+ Returns:
880
+ loss (Tensor): Scalar tensor representing the loss.
881
+ mask (Tensor): Boolean mask tensor of shape [B].
882
+ """
883
+ if mask is None:
884
+ mask = torch.ones_like(input_lengths, dtype=torch.float32, device=input_lengths.device)
885
+
886
+ log_probs = nn.functional.log_softmax(logits, dim=-1, dtype=torch.float32).transpose(0, 1)
887
+ labels_mask = labels >= 0
888
+ target_lengths = labels_mask.sum(-1)
889
+ flattened_targets = labels.masked_select(labels_mask)
890
+ with torch.backends.cudnn.flags(enabled=False):
891
+ masked_losses = nn.functional.ctc_loss(log_probs, flattened_targets, input_lengths, target_lengths, blank=pad_token_id, reduction="none", zero_infinity=ctc_zero_infinity)
892
+ loss = torch.sum(mask * masked_losses) / (torch.sum(mask) + 1e-6)
893
+
894
+ with torch.no_grad():
895
+ thres = 0.5
896
+ flattened_predictions = Wav2Vec2ForCTCConfig.greedy_decode_ctc(
897
+ log_probs.transpose(0, 1), # [B, T, V]
898
+ input_lengths=input_lengths,
899
+ blank_token_id=pad_token_id,
900
+ target_lengths=target_lengths
901
+ )
902
+ token_wise_mask = torch.eq(flattened_predictions, flattened_targets).to(flattened_targets.dtype)
903
+ segment_ids = torch.arange(len(target_lengths), device=target_lengths.device).repeat_interleave(target_lengths)
904
+ sequence_wise_mask = torch.zeros(len(target_lengths), dtype=target_lengths.dtype, device=token_wise_mask.device).scatter_add(0, segment_ids, token_wise_mask)
905
+ mask = mask * torch.ge(sequence_wise_mask, thres * target_lengths).to(flattened_targets.dtype)
906
+
907
+ return loss, mask
908
+ return func
909
+
910
+ @staticmethod
911
+ def gen_shared_head(self):
912
+ def func(hidden_states):
913
+ """
914
+ Args:
915
+ hidden_states (Tensor): Hidden States of shape [B, C, hidden_units].
916
+
917
+ Returns:
918
+ logits (Tensor): Logits tensor of shape [B, C, 2].
919
+ """
920
+ logits = self.lm_head(hidden_states)
921
+ # log_probs = nn.functional.log_softmax(logits, dim=-1, dtype=torch.float32).transpose(0, 1)
922
+ return logits
923
+ return func
924
+
925
+ @staticmethod
926
+ def gen_forward(lambdas, loss_normalization=False):
927
+ def func(
928
+ self,
929
+ input_values: Optional[torch.Tensor],
930
+ attention_mask: Optional[torch.Tensor] = None,
931
+ output_attentions: Optional[bool] = None,
932
+ output_hidden_states: Optional[bool] = None,
933
+ return_dict: Optional[bool] = None,
934
+ labels: Optional[torch.Tensor] = None,
935
+ ) -> Union[Tuple, CausalLMOutput]:
936
+ r"""
937
+ labels (`torch.LongTensor` of shape `(batch_size, target_length)`, *optional*):
938
+ Labels for connectionist temporal classification. Note that `target_length` has to be smaller or equal to
939
+ the sequence length of the output logits. Indices are selected in `[-100, 0, ..., config.vocab_size - 1]`.
940
+ All labels set to `-100` are ignored (masked), the loss is only computed for labels in `[0, ...,
941
+ config.vocab_size - 1]`.
942
+ """
943
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
944
+
945
+ if labels is not None and labels.max() >= self.config.vocab_size:
946
+ raise ValueError(f"Label values must be <= vocab_size: {self.config.vocab_size}")
947
+
948
+ outputs = self.wav2vec2(
949
+ input_values,
950
+ attention_mask=attention_mask,
951
+ output_attentions=output_attentions,
952
+ output_hidden_states=output_hidden_states,
953
+ return_dict=return_dict,
954
+ )
955
+
956
+ hidden_states = outputs[0]
957
+ hidden_states = self.dropout(hidden_states)
958
+
959
+ logits = self.lm_head(hidden_states)
960
+
961
+ loss = None
962
+ if labels is not None:
963
+ # retrieve loss input_lengths from attention_mask
964
+ attention_mask = (
965
+ attention_mask if attention_mask is not None else torch.ones_like(input_values, dtype=torch.long)
966
+ )
967
+ input_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long)
968
+ shared_head = Wav2Vec2ForCTCConfig.gen_shared_head(self)
969
+ criterion = Wav2Vec2ForCTCConfig.gen_criterion(input_lengths, self.config.pad_token_id, self.config.ctc_zero_infinity)
970
+ loss = trp_criterion(self.trp_blocks, shared_head, criterion, lambdas, hidden_states, logits, labels, loss_normalization) # NOTE: Apply TRP!
971
+
972
+ if not return_dict:
973
+ output = (logits,) + outputs[Wav2Vec2ForCTCConfig._HIDDEN_STATES_START_POSITION:]
974
+ return ((loss,) + output) if loss is not None else output
975
+
976
+ return CausalLMOutput(
977
+ loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions
978
+ )
979
+ return func
980
+
981
+
982
+ # MBart for Translation
983
+ class MBartForConditionalGenerationConfig(Config):
984
+ @staticmethod
985
+ def gen_criterion(vocab_size: int, top_k=1):
986
+ def func(logits, labels, mask=None):
987
+ """
988
+ Args:
989
+ logits (Tensor): Logits tensor of shape [B, L, V].
990
+ labels (Tensor): Target labels of shape [B, L].
991
+
992
+ Returns:
993
+ loss (Tensor): Scalar tensor representing the loss.
994
+ mask (Tensor): Boolean mask tensor of shape [B].
995
+ """
996
+ if mask is None:
997
+ mask = torch.ones_like(labels.view(-1), dtype=torch.float32, device=labels.device)
998
+
999
+ masked_losses = F.cross_entropy(logits.view(-1, vocab_size), labels.view(-1), reduction="none")
1000
+ loss = torch.sum(mask * masked_losses) / (torch.sum(mask) + 1e-6)
1001
+
1002
+ with torch.no_grad():
1003
+ topk_values, topk_indices = torch.topk(logits.view(-1, vocab_size), top_k, dim=1)
1004
+ mask = mask * torch.eq(topk_indices, labels.view(-1, 1)).any(dim=1).to(logits.dtype)
1005
+
1006
+ return loss, mask
1007
+ return func
1008
+
1009
+ @staticmethod
1010
+ def gen_shared_head(self):
1011
+ def func(hidden_states):
1012
+ """
1013
+ Args:
1014
+ hidden_states (Tensor): Hidden States of shape [B, L, hidden_units].
1015
+
1016
+ Returns:
1017
+ logits (Tensor): Logits tensor of shape [B, L].
1018
+ """
1019
+ logits = self.lm_head(hidden_states) + self.final_logits_bias
1020
+ return logits
1021
+ return func
1022
+
1023
+ @staticmethod
1024
+ def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int):
1025
+ """
1026
+ Shift input ids one token to the right, and wrap the last non pad token (the <LID> token) Note that MBart does not
1027
+ have a single `decoder_start_token_id` in contrast to other Bart-like models.
1028
+ """
1029
+ prev_output_tokens = input_ids.clone()
1030
+
1031
+ if pad_token_id is None:
1032
+ raise ValueError("self.model.config.pad_token_id has to be defined.")
1033
+ # replace possible -100 values in labels by `pad_token_id`
1034
+ prev_output_tokens.masked_fill_(prev_output_tokens == -100, pad_token_id)
1035
+
1036
+ index_of_eos = (prev_output_tokens.ne(pad_token_id).sum(dim=1) - 1).unsqueeze(-1)
1037
+ decoder_start_tokens = prev_output_tokens.gather(1, index_of_eos).squeeze()
1038
+ prev_output_tokens[:, 1:] = prev_output_tokens[:, :-1].clone()
1039
+ prev_output_tokens[:, 0] = decoder_start_tokens
1040
+
1041
+ return prev_output_tokens
1042
+
1043
+ @staticmethod
1044
+ def gen_forward(lambdas, loss_normalization=False):
1045
+ def func(
1046
+ self,
1047
+ input_ids: torch.LongTensor = None,
1048
+ attention_mask: Optional[torch.Tensor] = None,
1049
+ decoder_input_ids: Optional[torch.LongTensor] = None,
1050
+ decoder_attention_mask: Optional[torch.LongTensor] = None,
1051
+ head_mask: Optional[torch.Tensor] = None,
1052
+ decoder_head_mask: Optional[torch.Tensor] = None,
1053
+ cross_attn_head_mask: Optional[torch.Tensor] = None,
1054
+ encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
1055
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
1056
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1057
+ decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
1058
+ labels: Optional[torch.LongTensor] = None,
1059
+ use_cache: Optional[bool] = None,
1060
+ output_attentions: Optional[bool] = None,
1061
+ output_hidden_states: Optional[bool] = None,
1062
+ return_dict: Optional[bool] = None,
1063
+ ) -> Union[Seq2SeqLMOutput, Tuple[torch.FloatTensor]]:
1064
+ r"""
1065
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1066
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
1067
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1068
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1069
+
1070
+ Returns:
1071
+
1072
+ """
1073
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1074
+
1075
+ if labels is not None:
1076
+ # if use_cache:
1077
+ # logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.")
1078
+ use_cache = False
1079
+ if decoder_input_ids is None and decoder_inputs_embeds is None:
1080
+ decoder_input_ids = MBartForConditionalGenerationConfig.shift_tokens_right(labels, self.config.pad_token_id)
1081
+
1082
+ outputs = self.model(
1083
+ input_ids,
1084
+ attention_mask=attention_mask,
1085
+ decoder_input_ids=decoder_input_ids,
1086
+ encoder_outputs=encoder_outputs,
1087
+ decoder_attention_mask=decoder_attention_mask,
1088
+ head_mask=head_mask,
1089
+ decoder_head_mask=decoder_head_mask,
1090
+ cross_attn_head_mask=cross_attn_head_mask,
1091
+ past_key_values=past_key_values,
1092
+ inputs_embeds=inputs_embeds,
1093
+ decoder_inputs_embeds=decoder_inputs_embeds,
1094
+ use_cache=use_cache,
1095
+ output_attentions=output_attentions,
1096
+ output_hidden_states=output_hidden_states,
1097
+ return_dict=return_dict,
1098
+ )
1099
+ lm_logits = self.lm_head(outputs[0]) + self.final_logits_bias
1100
+
1101
+ masked_lm_loss = None
1102
+ if labels is not None:
1103
+ shared_head = MBartForConditionalGenerationConfig.gen_shared_head(self)
1104
+ criterion = MBartForConditionalGenerationConfig.gen_criterion(self.config.vocab_size)
1105
+ masked_lm_loss = trp_criterion(self.trp_blocks, shared_head, criterion, lambdas, outputs[0], lm_logits, labels, loss_normalization)
1106
+
1107
+ if not return_dict:
1108
+ output = (lm_logits,) + outputs[1:]
1109
+ return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
1110
+
1111
+ return Seq2SeqLMOutput(
1112
+ loss=masked_lm_loss,
1113
+ logits=lm_logits,
1114
+ past_key_values=outputs.past_key_values,
1115
+ decoder_hidden_states=outputs.decoder_hidden_states,
1116
+ decoder_attentions=outputs.decoder_attentions,
1117
+ cross_attentions=outputs.cross_attentions,
1118
+ encoder_last_hidden_state=outputs.encoder_last_hidden_state,
1119
+ encoder_hidden_states=outputs.encoder_hidden_states,
1120
+ encoder_attentions=outputs.encoder_attentions,
1121
+ )
1122
+ return func
1123
+
1124
+
1125
+ def apply_trp(model, depths: int, p: float, lambdas: List[float], **kwargs):
1126
+ if isinstance(model, transformers.Wav2Vec2ForSequenceClassification):
1127
+ print("✅ Applying TRP to Wav2Vec2 for Audio Classification...")
1128
+ model.trp_blocks = torch.nn.ModuleList([TPBlock(depths, 768, p) for _ in lambdas])
1129
+ model.forward = types.MethodType(Wav2Vec2ForSequenceClassificationConfig.gen_forward(lambdas, False), model)
1130
+ elif isinstance(model, MobileNetV2):
1131
+ print("✅ Applying TRP to MobileNetV2 for Image Classification...")
1132
+ model.trp_blocks = torch.nn.ModuleList([TPBlock(depths, 1280, p) for _ in lambdas])
1133
+ model.forward = types.MethodType(MobileNetV2Config.gen_forward(lambdas, True, label_smoothing=0.0, top_k=1), model)
1134
+ elif isinstance(model, ResNet):
1135
+ print("✅ Applying TRP to ResNet for Image Classification...")
1136
+ model.trp_blocks = torch.nn.ModuleList([TPBlock(depths, 2048, p) for _ in lambdas])
1137
+ model.forward = types.MethodType(ResNetConfig.gen_forward(lambdas, True, label_smoothing=0.0, top_k=1), model)
1138
+ elif isinstance(model, EfficientNet):
1139
+ print("✅ Applying TRP to EfficientNet for Image Classification...")
1140
+ model.trp_blocks = torch.nn.ModuleList([TPBlock(depths, 1280, p) for _ in lambdas])
1141
+ model.forward = types.MethodType(EfficientNetConfig.gen_forward(lambdas, True, label_smoothing=kwargs["label_smoothing"], top_k=1), model)
1142
+ elif isinstance(model, VisionTransformer):
1143
+ print("✅ Applying TRP to VisionTransformer for Image Classification...")
1144
+ model.trp_blocks = torch.nn.ModuleList([TPBlock(depths, 768, p) for _ in lambdas])
1145
+ model.forward = types.MethodType(VisionTransformerConfig.gen_forward(lambdas, True, label_smoothing=kwargs["label_smoothing"], top_k=1), model)
1146
+ elif isinstance(model, transformers.BertForQuestionAnswering):
1147
+ print("✅ Applying TRP to Bert for Question Answering...")
1148
+ model.trp_blocks = torch.nn.ModuleList([TPBlock(depths, 768, p) for _ in lambdas])
1149
+ model.forward = types.MethodType(BertForQuestionAnsweringConfig.gen_forward(lambdas, True, 1), model)
1150
+ elif isinstance(model, FCN):
1151
+ print("✅ Applying TRP to FCN for Semantic Segmentation...")
1152
+ model.out_trp_blocks = torch.nn.ModuleList([TPBlock(depths, 2048, p, dim=1) for _ in lambdas])
1153
+ model.aux_trp_blocks = torch.nn.ModuleList([TPBlock(depths, 1024, p, dim=1) for _ in lambdas])
1154
+ model.forward = types.MethodType(FCNConfig.gen_forward(lambdas, True, 1), model)
1155
+ elif isinstance(model, DeepLabV3):
1156
+ print("✅ Applying TRP to DeepLabV3 for Semantic Segmentation...")
1157
+ model.out_trp_blocks = torch.nn.ModuleList([TPBlock(depths, 2048, p, dim=1) for _ in lambdas])
1158
+ model.aux_trp_blocks = torch.nn.ModuleList([TPBlock(depths, 1024, p, dim=1) for _ in lambdas])
1159
+ model.forward = types.MethodType(DeepLabV3Config.gen_forward(lambdas, True, 1), model)
1160
+ elif isinstance(model, transformers.BertForSequenceClassification):
1161
+ print("✅ Applying TRP to Bert for Text Classification...")
1162
+ model.trp_blocks = torch.nn.ModuleList([TPBlock(depths, 768, p) for _ in lambdas])
1163
+ model.forward = types.MethodType(BertForSequenceClassificationConfig.gen_forward(lambdas, False), model)
1164
+ elif isinstance(model, transformers.RobertaForSequenceClassification):
1165
+ print("✅ Applying TRP to Roberta for Text Classification...")
1166
+ model.trp_blocks = torch.nn.ModuleList([TPBlock(depths, 768, p) for _ in lambdas])
1167
+ model.forward = types.MethodType(RobertaForSequenceClassificationConfig.gen_forward(lambdas, False), model)
1168
+ elif isinstance(model, transformers.Wav2Vec2ForCTC):
1169
+ print("✅ Applying TRP to Wav2Vec2 for Speech Recognition...")
1170
+ model.trp_blocks = torch.nn.ModuleList([TPBlock(depths, 1024, p) for _ in lambdas])
1171
+ model.forward = types.MethodType(Wav2Vec2ForCTCConfig.gen_forward(lambdas, False), model)
1172
+ elif isinstance(model, transformers.MBartForConditionalGeneration):
1173
+ print("✅ Applying TRP to MBart for Translation...")
1174
+ model.trp_blocks = torch.nn.ModuleList([TPBlock(depths, 1024, p) for _ in lambdas])
1175
+ model.forward = types.MethodType(MBartForConditionalGenerationConfig.gen_forward(lambdas, False), model)
1176
+ else:
1177
+ torch._assert(
1178
+ isinstance(model, transformers.Wav2Vec2ForSequenceClassification),
1179
+ "The model should be an object of [`Wav2Vec2ForSequenceClassification`].")
1180
+
1181
+ return model
hpo-examples/image-classification/__pycache__/presets.cpython-310.pyc ADDED
Binary file (2.31 kB). View file
 
hpo-examples/image-classification/__pycache__/sampler.cpython-310.pyc ADDED
Binary file (2.41 kB). View file
 
hpo-examples/image-classification/__pycache__/transforms.cpython-310.pyc ADDED
Binary file (5.29 kB). View file
 
hpo-examples/image-classification/__pycache__/trplib.cpython-310.pyc ADDED
Binary file (37.5 kB). View file
 
hpo-examples/image-classification/__pycache__/utils.cpython-310.pyc ADDED
Binary file (14.9 kB). View file
 
hpo-examples/image-classification/efficientnet_v2_m/model_7.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5f4b45a082517a3e60498e92710f5a97b5869515db52e536d9c92a3b68ae4e8f
3
+ size 454515355
hpo-examples/image-classification/mobilenetv2/model_32.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f240e3b954e1b1786878733c3045125461fa67bcb6ce50a200d5ec6e46081bc7
3
+ size 48002008
hpo-examples/image-classification/presets.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torchvision.transforms import autoaugment, transforms
3
+ from torchvision.transforms.functional import InterpolationMode
4
+
5
+
6
+ class ClassificationPresetTrain:
7
+ def __init__(
8
+ self,
9
+ *,
10
+ crop_size,
11
+ mean=(0.485, 0.456, 0.406),
12
+ std=(0.229, 0.224, 0.225),
13
+ interpolation=InterpolationMode.BILINEAR,
14
+ hflip_prob=0.5,
15
+ auto_augment_policy=None,
16
+ ra_magnitude=9,
17
+ augmix_severity=3,
18
+ random_erase_prob=0.0,
19
+ ):
20
+ trans = [transforms.RandomResizedCrop(crop_size, interpolation=interpolation)]
21
+ if hflip_prob > 0:
22
+ trans.append(transforms.RandomHorizontalFlip(hflip_prob))
23
+ if auto_augment_policy is not None:
24
+ if auto_augment_policy == "ra":
25
+ trans.append(autoaugment.RandAugment(interpolation=interpolation, magnitude=ra_magnitude))
26
+ elif auto_augment_policy == "ta_wide":
27
+ trans.append(autoaugment.TrivialAugmentWide(interpolation=interpolation))
28
+ elif auto_augment_policy == "augmix":
29
+ trans.append(autoaugment.AugMix(interpolation=interpolation, severity=augmix_severity))
30
+ else:
31
+ aa_policy = autoaugment.AutoAugmentPolicy(auto_augment_policy)
32
+ trans.append(autoaugment.AutoAugment(policy=aa_policy, interpolation=interpolation))
33
+ trans.extend(
34
+ [
35
+ transforms.PILToTensor(),
36
+ transforms.ConvertImageDtype(torch.float),
37
+ transforms.Normalize(mean=mean, std=std),
38
+ ]
39
+ )
40
+ if random_erase_prob > 0:
41
+ trans.append(transforms.RandomErasing(p=random_erase_prob))
42
+
43
+ self.transforms = transforms.Compose(trans)
44
+
45
+ def __call__(self, img):
46
+ return self.transforms(img)
47
+
48
+
49
+ class ClassificationPresetEval:
50
+ def __init__(
51
+ self,
52
+ *,
53
+ crop_size,
54
+ resize_size=256,
55
+ mean=(0.485, 0.456, 0.406),
56
+ std=(0.229, 0.224, 0.225),
57
+ interpolation=InterpolationMode.BILINEAR,
58
+ ):
59
+
60
+ self.transforms = transforms.Compose(
61
+ [
62
+ transforms.Resize(resize_size, interpolation=interpolation),
63
+ transforms.CenterCrop(crop_size),
64
+ transforms.PILToTensor(),
65
+ transforms.ConvertImageDtype(torch.float),
66
+ transforms.Normalize(mean=mean, std=std),
67
+ ]
68
+ )
69
+
70
+ def __call__(self, img):
71
+ return self.transforms(img)
hpo-examples/image-classification/resnet50/model_35.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3b304e1309e3143b9e7109d5b7263369e4e93c949dde9bab44b3d3b193d16361
3
+ size 255177167
hpo-examples/image-classification/run.sh ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ✅ --lr 0.00002 Acc@1 71.878 Acc@5 90.286 -> Acc@1 72.104 Acc@5 90.316 (with normalization)
2
+ torchrun --nproc_per_node=4 train.py\
3
+ --data-path /home/cs/Documents/datasets/imagenet\
4
+ --model mobilenet_v2 --output-dir mobilenet_v2 --weights MobileNet_V2_Weights.IMAGENET1K_V1\
5
+ --batch-size 192 --epochs 40 --lr 0.0004 --lr-step-size 10 --lr-gamma 0.5 --wd 0.00004 --apply-trp --trp-depths 1 --trp-p 0.15 --trp-lambdas 0.4 0.2 0.1
6
+ # torchrun --nproc_per_node=4 train.py\
7
+ # --data-path /home/cs/Documents/datasets/imagenet\
8
+ # --model mobilenet_v2 --resume mobilenet_v2/model_32.pth --test-only
9
+
10
+
11
+ # ✅ --lr 0.0002 Acc@1 76.130 Acc@5 92.862 -> Acc@1 77.234 Acc@5 93.322 (with normalization)
12
+ torchrun --nproc_per_node=4 train.py\
13
+ --data-path /home/cs/Documents/datasets/imagenet\
14
+ --model resnet50 --output-dir resnet50 --weights ResNet50_Weights.IMAGENET1K_V1\
15
+ --batch-size 64 --epochs 40 --lr 0.0004 --lr-step-size 10 --lr-gamma 0.5 --print-freq 100\
16
+ --apply-trp --trp-depths 1 --trp-p 0.2 --trp-lambdas 0.4 0.2 0.1
17
+ # torchrun --nproc_per_node=4 train.py\
18
+ # --data-path /home/cs/Documents/datasets/imagenet\
19
+ # --model resnet50 --resume resnet50/model_35.pth --test-only
20
+
21
+
22
+ # ✅ Test: Acc@1 85.218 Acc@5 97.208
23
+ torchrun --nproc_per_node=4 train.py \
24
+ --data-path /home/cs/Documents/datasets/imagenet\
25
+ --model efficientnet_v2_m --output-dir efficientnet_v2_m --weights EfficientNet_V2_M_Weights.IMAGENET1K_V1\
26
+ --epochs 10 --batch-size 64 --lr 5e-9 --lr-scheduler cosineannealinglr --weight-decay 0.00002 \
27
+ --lr-warmup-method constant --lr-warmup-epochs 8 --lr-warmup-decay 0. \
28
+ --auto-augment ta_wide --random-erase 0.1 --label-smoothing 0.1 --mixup-alpha 0.2 --cutmix-alpha 1.0 --norm-weight-decay 0.0 \
29
+ --train-crop-size 384 --val-crop-size 480 --val-resize-size 480 --ra-sampler --ra-reps 4 --print-freq 100\
30
+ --apply-trp --trp-depths 1 --trp-p 0.2 --trp-lambdas 0.4 0.2 0.1
31
+ # torchrun --nproc_per_node=4 train.py\
32
+ # --data-path /home/cs/Documents/datasets/imagenet\
33
+ # --model efficientnet_v2_m --resume efficientnet_v2_m/model_7.pth --test-only\
34
+ # --val-crop-size 480 --val-resize-size 480
35
+
36
+
37
+ # ✅ Test: Acc@1 81.092 Acc@5 95.304
38
+ torchrun --nproc_per_node=4 train.py\
39
+ --data-path /home/cs/Documents/datasets/imagenet\
40
+ --model vit_b_16 --output-dir vit_b_16 --weights ViT_B_16_Weights.IMAGENET1K_V1\
41
+ --epochs 5 --batch-size 196 --opt adamw --lr 5e-9 --lr-scheduler cosineannealinglr --wd 0.3\
42
+ --lr-warmup-method constant --lr-warmup-epochs 3 --lr-warmup-decay 0. \
43
+ --amp --label-smoothing 0.11 --mixup-alpha 0.2 --auto-augment ra --clip-grad-norm 1 --cutmix-alpha 1.0\
44
+ --apply-trp --trp-depths 1 --trp-p 0.1 --trp-lambdas 0.4 0.2 0.1 --print-freq 100
45
+ # torchrun --nproc_per_node=4 train.py\
46
+ # --data-path /home/cs/Documents/datasets/imagenet\
47
+ # --model vit_b_16 --resume vit_b_16/model_4.pth --test-only
48
+
49
+
hpo-examples/image-classification/sampler.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+
3
+ import torch
4
+ import torch.distributed as dist
5
+
6
+
7
+ class RASampler(torch.utils.data.Sampler):
8
+ """Sampler that restricts data loading to a subset of the dataset for distributed,
9
+ with repeated augmentation.
10
+ It ensures that different each augmented version of a sample will be visible to a
11
+ different process (GPU).
12
+ Heavily based on 'torch.utils.data.DistributedSampler'.
13
+
14
+ This is borrowed from the DeiT Repo:
15
+ https://github.com/facebookresearch/deit/blob/main/samplers.py
16
+ """
17
+
18
+ def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True, seed=0, repetitions=3):
19
+ if num_replicas is None:
20
+ if not dist.is_available():
21
+ raise RuntimeError("Requires distributed package to be available!")
22
+ num_replicas = dist.get_world_size()
23
+ if rank is None:
24
+ if not dist.is_available():
25
+ raise RuntimeError("Requires distributed package to be available!")
26
+ rank = dist.get_rank()
27
+ self.dataset = dataset
28
+ self.num_replicas = num_replicas
29
+ self.rank = rank
30
+ self.epoch = 0
31
+ self.num_samples = int(math.ceil(len(self.dataset) * float(repetitions) / self.num_replicas))
32
+ self.total_size = self.num_samples * self.num_replicas
33
+ self.num_selected_samples = int(math.floor(len(self.dataset) // 256 * 256 / self.num_replicas))
34
+ self.shuffle = shuffle
35
+ self.seed = seed
36
+ self.repetitions = repetitions
37
+
38
+ def __iter__(self):
39
+ if self.shuffle:
40
+ # Deterministically shuffle based on epoch
41
+ g = torch.Generator()
42
+ g.manual_seed(self.seed + self.epoch)
43
+ indices = torch.randperm(len(self.dataset), generator=g).tolist()
44
+ else:
45
+ indices = list(range(len(self.dataset)))
46
+
47
+ # Add extra samples to make it evenly divisible
48
+ indices = [ele for ele in indices for i in range(self.repetitions)]
49
+ indices += indices[: (self.total_size - len(indices))]
50
+ assert len(indices) == self.total_size
51
+
52
+ # Subsample
53
+ indices = indices[self.rank : self.total_size : self.num_replicas]
54
+ assert len(indices) == self.num_samples
55
+
56
+ return iter(indices[: self.num_selected_samples])
57
+
58
+ def __len__(self):
59
+ return self.num_selected_samples
60
+
61
+ def set_epoch(self, epoch):
62
+ self.epoch = epoch
hpo-examples/image-classification/train.py ADDED
@@ -0,0 +1,524 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import datetime
2
+ import os
3
+ import time
4
+ import warnings
5
+
6
+ import presets
7
+ import torch
8
+ import torch.utils.data
9
+ import torchvision
10
+ import transforms
11
+ import utils
12
+ from sampler import RASampler
13
+ from torch import nn
14
+ from torch.utils.data.dataloader import default_collate
15
+ from torchvision.transforms.functional import InterpolationMode
16
+
17
+ from trplib import apply_trp
18
+
19
+
20
+ def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args, model_ema=None, scaler=None):
21
+ model.train()
22
+ metric_logger = utils.MetricLogger(delimiter=" ")
23
+ metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value}"))
24
+ metric_logger.add_meter("img/s", utils.SmoothedValue(window_size=10, fmt="{value}"))
25
+
26
+ header = f"Epoch: [{epoch}]"
27
+ for i, (image, target) in enumerate(metric_logger.log_every(data_loader, args.print_freq, header)):
28
+ start_time = time.time()
29
+ image, target = image.to(device), target.to(device)
30
+ with torch.amp.autocast("cuda", enabled=scaler is not None):
31
+ # output = model(image)
32
+ # loss = criterion(output, target)
33
+ output, loss = model(image, target)
34
+
35
+ optimizer.zero_grad()
36
+ if scaler is not None:
37
+ scaler.scale(loss).backward()
38
+ if args.clip_grad_norm is not None:
39
+ # we should unscale the gradients of optimizer's assigned params if do gradient clipping
40
+ scaler.unscale_(optimizer)
41
+ nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad_norm)
42
+ scaler.step(optimizer)
43
+ scaler.update()
44
+ else:
45
+ loss.backward()
46
+ if args.clip_grad_norm is not None:
47
+ nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad_norm)
48
+ optimizer.step()
49
+
50
+ if model_ema and i % args.model_ema_steps == 0:
51
+ model_ema.update_parameters(model)
52
+ if epoch < args.lr_warmup_epochs:
53
+ # Reset ema buffer to keep copying weights during warmup period
54
+ model_ema.n_averaged.fill_(0)
55
+
56
+ acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
57
+ batch_size = image.shape[0]
58
+ metric_logger.update(loss=loss.item(), lr=optimizer.param_groups[0]["lr"])
59
+ metric_logger.meters["acc1"].update(acc1.item(), n=batch_size)
60
+ metric_logger.meters["acc5"].update(acc5.item(), n=batch_size)
61
+ metric_logger.meters["img/s"].update(batch_size / (time.time() - start_time))
62
+
63
+
64
+ def evaluate(model, criterion, data_loader, device, print_freq=100, log_suffix=""):
65
+ model.eval()
66
+ metric_logger = utils.MetricLogger(delimiter=" ")
67
+ header = f"Test: {log_suffix}"
68
+
69
+ num_processed_samples = 0
70
+ with torch.inference_mode():
71
+ for image, target in metric_logger.log_every(data_loader, print_freq, header):
72
+ image = image.to(device, non_blocking=True)
73
+ target = target.to(device, non_blocking=True)
74
+ output = model(image)
75
+ loss = criterion(output, target)
76
+
77
+ acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
78
+ # FIXME need to take into account that the datasets
79
+ # could have been padded in distributed setup
80
+ batch_size = image.shape[0]
81
+ metric_logger.update(loss=loss.item())
82
+ metric_logger.meters["acc1"].update(acc1.item(), n=batch_size)
83
+ metric_logger.meters["acc5"].update(acc5.item(), n=batch_size)
84
+ num_processed_samples += batch_size
85
+ # gather the stats from all processes
86
+
87
+ num_processed_samples = utils.reduce_across_processes(num_processed_samples)
88
+ if (
89
+ hasattr(data_loader.dataset, "__len__")
90
+ and len(data_loader.dataset) != num_processed_samples
91
+ and torch.distributed.get_rank() == 0
92
+ ):
93
+ # See FIXME above
94
+ warnings.warn(
95
+ f"It looks like the dataset has {len(data_loader.dataset)} samples, but {num_processed_samples} "
96
+ "samples were used for the validation, which might bias the results. "
97
+ "Try adjusting the batch size and / or the world size. "
98
+ "Setting the world size to 1 is always a safe bet."
99
+ )
100
+
101
+ metric_logger.synchronize_between_processes()
102
+
103
+ print(f"{header} Acc@1 {metric_logger.acc1.global_avg:.3f} Acc@5 {metric_logger.acc5.global_avg:.3f}")
104
+ return metric_logger.acc1.global_avg
105
+
106
+
107
+ def _get_cache_path(filepath):
108
+ import hashlib
109
+
110
+ h = hashlib.sha1(filepath.encode()).hexdigest()
111
+ cache_path = os.path.join("~", ".torch", "vision", "datasets", "imagefolder", h[:10] + ".pt")
112
+ cache_path = os.path.expanduser(cache_path)
113
+ return cache_path
114
+
115
+
116
+ def load_data(traindir, valdir, args):
117
+ # Data loading code
118
+ print("Loading data")
119
+ val_resize_size, val_crop_size, train_crop_size = (
120
+ args.val_resize_size,
121
+ args.val_crop_size,
122
+ args.train_crop_size,
123
+ )
124
+ interpolation = InterpolationMode(args.interpolation)
125
+
126
+ print("Loading training data")
127
+ st = time.time()
128
+ cache_path = _get_cache_path(traindir)
129
+ if args.cache_dataset and os.path.exists(cache_path):
130
+ # Attention, as the transforms are also cached!
131
+ print(f"Loading dataset_train from {cache_path}")
132
+ dataset, _ = torch.load(cache_path)
133
+ else:
134
+ auto_augment_policy = getattr(args, "auto_augment", None)
135
+ random_erase_prob = getattr(args, "random_erase", 0.0)
136
+ ra_magnitude = args.ra_magnitude
137
+ augmix_severity = args.augmix_severity
138
+ dataset = torchvision.datasets.ImageFolder(
139
+ traindir,
140
+ presets.ClassificationPresetTrain(
141
+ crop_size=train_crop_size,
142
+ interpolation=interpolation,
143
+ auto_augment_policy=auto_augment_policy,
144
+ random_erase_prob=random_erase_prob,
145
+ ra_magnitude=ra_magnitude,
146
+ augmix_severity=augmix_severity,
147
+ ),
148
+ )
149
+ if args.cache_dataset:
150
+ print(f"Saving dataset_train to {cache_path}")
151
+ utils.mkdir(os.path.dirname(cache_path))
152
+ utils.save_on_master((dataset, traindir), cache_path)
153
+ print("Took", time.time() - st)
154
+
155
+ print("Loading validation data")
156
+ cache_path = _get_cache_path(valdir)
157
+ if args.cache_dataset and os.path.exists(cache_path):
158
+ # Attention, as the transforms are also cached!
159
+ print(f"Loading dataset_test from {cache_path}")
160
+ dataset_test, _ = torch.load(cache_path)
161
+ else:
162
+ if args.weights and args.test_only:
163
+ weights = torchvision.models.get_weight(args.weights)
164
+ preprocessing = weights.transforms()
165
+ else:
166
+ preprocessing = presets.ClassificationPresetEval(
167
+ crop_size=val_crop_size, resize_size=val_resize_size, interpolation=interpolation
168
+ )
169
+
170
+ dataset_test = torchvision.datasets.ImageFolder(
171
+ valdir,
172
+ preprocessing,
173
+ )
174
+ if args.cache_dataset:
175
+ print(f"Saving dataset_test to {cache_path}")
176
+ utils.mkdir(os.path.dirname(cache_path))
177
+ utils.save_on_master((dataset_test, valdir), cache_path)
178
+
179
+ print("Creating data loaders")
180
+ if args.distributed:
181
+ if hasattr(args, "ra_sampler") and args.ra_sampler:
182
+ train_sampler = RASampler(dataset, shuffle=True, repetitions=args.ra_reps)
183
+ else:
184
+ train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
185
+ test_sampler = torch.utils.data.distributed.DistributedSampler(dataset_test, shuffle=False)
186
+ else:
187
+ train_sampler = torch.utils.data.RandomSampler(dataset)
188
+ test_sampler = torch.utils.data.SequentialSampler(dataset_test)
189
+
190
+ return dataset, dataset_test, train_sampler, test_sampler
191
+
192
+
193
+ def main(args):
194
+ if args.output_dir:
195
+ utils.mkdir(args.output_dir)
196
+
197
+ utils.init_distributed_mode(args)
198
+ print(args)
199
+
200
+ device = torch.device(args.device)
201
+
202
+ if args.use_deterministic_algorithms:
203
+ torch.backends.cudnn.benchmark = False
204
+ torch.use_deterministic_algorithms(True)
205
+ else:
206
+ torch.backends.cudnn.benchmark = True
207
+
208
+ train_dir = os.path.join(args.data_path, "train")
209
+ val_dir = os.path.join(args.data_path, "val")
210
+ dataset, dataset_test, train_sampler, test_sampler = load_data(train_dir, val_dir, args)
211
+
212
+ collate_fn = None
213
+ num_classes = len(dataset.classes)
214
+ mixup_transforms = []
215
+ if args.mixup_alpha > 0.0:
216
+ mixup_transforms.append(transforms.RandomMixup(num_classes, p=1.0, alpha=args.mixup_alpha))
217
+ if args.cutmix_alpha > 0.0:
218
+ mixup_transforms.append(transforms.RandomCutmix(num_classes, p=1.0, alpha=args.cutmix_alpha))
219
+ if mixup_transforms:
220
+ mixupcutmix = torchvision.transforms.RandomChoice(mixup_transforms)
221
+
222
+ def collate_fn(batch):
223
+ return mixupcutmix(*default_collate(batch))
224
+
225
+ data_loader = torch.utils.data.DataLoader(
226
+ dataset,
227
+ batch_size=args.batch_size,
228
+ sampler=train_sampler,
229
+ num_workers=args.workers,
230
+ pin_memory=True,
231
+ collate_fn=collate_fn,
232
+ )
233
+ data_loader_test = torch.utils.data.DataLoader(
234
+ dataset_test, batch_size=8, sampler=test_sampler, num_workers=args.workers, pin_memory=True
235
+ )
236
+
237
+ print("Creating model")
238
+ model = torchvision.models.get_model(args.model, weights=args.weights, num_classes=num_classes)
239
+ if args.apply_trp:
240
+ model = apply_trp(model, args.trp_depths, args.trp_p, args.trp_lambdas, label_smoothing=args.label_smoothing)
241
+ model.to(device)
242
+
243
+ if args.distributed and args.sync_bn:
244
+ model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
245
+
246
+ criterion = nn.CrossEntropyLoss(label_smoothing=args.label_smoothing)
247
+
248
+ custom_keys_weight_decay = []
249
+ if args.bias_weight_decay is not None:
250
+ custom_keys_weight_decay.append(("bias", args.bias_weight_decay))
251
+ if args.transformer_embedding_decay is not None:
252
+ for key in ["class_token", "position_embedding", "relative_position_bias_table"]:
253
+ custom_keys_weight_decay.append((key, args.transformer_embedding_decay))
254
+ parameters = utils.set_weight_decay(
255
+ model,
256
+ args.weight_decay,
257
+ norm_weight_decay=args.norm_weight_decay,
258
+ custom_keys_weight_decay=custom_keys_weight_decay if len(custom_keys_weight_decay) > 0 else None,
259
+ )
260
+
261
+ opt_name = args.opt.lower()
262
+ if opt_name.startswith("sgd"):
263
+ optimizer = torch.optim.SGD(
264
+ parameters,
265
+ lr=args.lr,
266
+ momentum=args.momentum,
267
+ weight_decay=args.weight_decay,
268
+ nesterov="nesterov" in opt_name,
269
+ )
270
+ elif opt_name == "rmsprop":
271
+ optimizer = torch.optim.RMSprop(
272
+ parameters, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay, eps=0.0316, alpha=0.9
273
+ )
274
+ elif opt_name == "adamw":
275
+ optimizer = torch.optim.AdamW(parameters, lr=args.lr, weight_decay=args.weight_decay)
276
+ else:
277
+ raise RuntimeError(f"Invalid optimizer {args.opt}. Only SGD, RMSprop and AdamW are supported.")
278
+
279
+ scaler = torch.amp.GradScaler("cuda") if args.amp else None
280
+
281
+ args.lr_scheduler = args.lr_scheduler.lower()
282
+ if args.lr_scheduler == "steplr":
283
+ main_lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma)
284
+ elif args.lr_scheduler == "cosineannealinglr":
285
+ main_lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
286
+ optimizer, T_max=args.epochs - args.lr_warmup_epochs, eta_min=args.lr_min
287
+ )
288
+ elif args.lr_scheduler == "exponentiallr":
289
+ main_lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=args.lr_gamma)
290
+ else:
291
+ raise RuntimeError(
292
+ f"Invalid lr scheduler '{args.lr_scheduler}'. Only StepLR, CosineAnnealingLR and ExponentialLR "
293
+ "are supported."
294
+ )
295
+
296
+ if args.lr_warmup_epochs > 0:
297
+ if args.lr_warmup_method == "linear":
298
+ warmup_lr_scheduler = torch.optim.lr_scheduler.LinearLR(
299
+ optimizer, start_factor=args.lr_warmup_decay, total_iters=args.lr_warmup_epochs
300
+ )
301
+ elif args.lr_warmup_method == "constant":
302
+ warmup_lr_scheduler = torch.optim.lr_scheduler.ConstantLR(
303
+ optimizer, factor=args.lr_warmup_decay, total_iters=args.lr_warmup_epochs
304
+ )
305
+ else:
306
+ raise RuntimeError(
307
+ f"Invalid warmup lr method '{args.lr_warmup_method}'. Only linear and constant are supported."
308
+ )
309
+ lr_scheduler = torch.optim.lr_scheduler.SequentialLR(
310
+ optimizer, schedulers=[warmup_lr_scheduler, main_lr_scheduler], milestones=[args.lr_warmup_epochs]
311
+ )
312
+ else:
313
+ lr_scheduler = main_lr_scheduler
314
+
315
+ model_without_ddp = model
316
+ if args.distributed:
317
+ model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
318
+ model_without_ddp = model.module
319
+
320
+ model_ema = None
321
+ if args.model_ema:
322
+ # Decay adjustment that aims to keep the decay independent from other hyper-parameters originally proposed at:
323
+ # https://github.com/facebookresearch/pycls/blob/f8cd9627/pycls/core/net.py#L123
324
+ #
325
+ # total_ema_updates = (Dataset_size / n_GPUs) * epochs / (batch_size_per_gpu * EMA_steps)
326
+ # We consider constant = Dataset_size for a given dataset/setup and ommit it. Thus:
327
+ # adjust = 1 / total_ema_updates ~= n_GPUs * batch_size_per_gpu * EMA_steps / epochs
328
+ adjust = args.world_size * args.batch_size * args.model_ema_steps / args.epochs
329
+ alpha = 1.0 - args.model_ema_decay
330
+ alpha = min(1.0, alpha * adjust)
331
+ model_ema = utils.ExponentialMovingAverage(model_without_ddp, device=device, decay=1.0 - alpha)
332
+
333
+ if args.resume:
334
+ checkpoint = torch.load(args.resume, map_location="cpu", weights_only=False)
335
+ model_without_ddp.load_state_dict(checkpoint["model"])
336
+ if not args.test_only:
337
+ optimizer.load_state_dict(checkpoint["optimizer"])
338
+ lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
339
+ args.start_epoch = checkpoint["epoch"] + 1
340
+ if model_ema:
341
+ model_ema.load_state_dict(checkpoint["model_ema"])
342
+ if scaler:
343
+ scaler.load_state_dict(checkpoint["scaler"])
344
+
345
+ if args.test_only:
346
+ # We disable the cudnn benchmarking because it can noticeably affect the accuracy
347
+ torch.backends.cudnn.benchmark = False
348
+ torch.backends.cudnn.deterministic = True
349
+ if model_ema:
350
+ evaluate(model_ema, criterion, data_loader_test, device=device, log_suffix="EMA")
351
+ else:
352
+ evaluate(model, criterion, data_loader_test, device=device)
353
+ return
354
+
355
+ print("Start training")
356
+ start_time = time.time()
357
+ for epoch in range(args.start_epoch, args.epochs):
358
+ if args.distributed:
359
+ train_sampler.set_epoch(epoch)
360
+ train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args, model_ema, scaler)
361
+ lr_scheduler.step()
362
+ evaluate(model, criterion, data_loader_test, device=device)
363
+ if model_ema:
364
+ evaluate(model_ema, criterion, data_loader_test, device=device, log_suffix="EMA")
365
+ if args.output_dir:
366
+ checkpoint = {
367
+ "model": model_without_ddp.state_dict() if not args.apply_trp else {k: v for k, v in model_without_ddp.state_dict().items() if not k.startswith("trp_blocks")}, # NOTE: remove TRP heads
368
+ "optimizer": optimizer.state_dict(),
369
+ "lr_scheduler": lr_scheduler.state_dict(),
370
+ "epoch": epoch,
371
+ "args": args,
372
+ }
373
+ if model_ema:
374
+ checkpoint["model_ema"] = model_ema.state_dict() if not args.apply_trp else {k: v for k, v in model_ema.state_dict().items() if not k.startswith("trp_blocks")} # NOTE: remove TRP heads
375
+ if scaler:
376
+ checkpoint["scaler"] = scaler.state_dict()
377
+ utils.save_on_master(checkpoint, os.path.join(args.output_dir, f"model_{epoch}.pth"))
378
+ utils.save_on_master(checkpoint, os.path.join(args.output_dir, "checkpoint.pth"))
379
+
380
+ total_time = time.time() - start_time
381
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
382
+ print(f"Training time {total_time_str}")
383
+
384
+
385
+ def get_args_parser(add_help=True):
386
+ import argparse
387
+
388
+ parser = argparse.ArgumentParser(description="PyTorch Classification Training", add_help=add_help)
389
+
390
+ parser.add_argument("--data-path", default="/datasets01/imagenet_full_size/061417/", type=str, help="dataset path")
391
+ parser.add_argument("--model", default="resnet18", type=str, help="model name")
392
+ parser.add_argument("--device", default="cuda", type=str, help="device (Use cuda or cpu Default: cuda)")
393
+ parser.add_argument(
394
+ "-b", "--batch-size", default=32, type=int, help="images per gpu, the total batch size is $NGPU x batch_size"
395
+ )
396
+ parser.add_argument("--epochs", default=90, type=int, metavar="N", help="number of total epochs to run")
397
+ parser.add_argument(
398
+ "-j", "--workers", default=16, type=int, metavar="N", help="number of data loading workers (default: 16)"
399
+ )
400
+ parser.add_argument("--opt", default="sgd", type=str, help="optimizer")
401
+ parser.add_argument("--lr", default=0.1, type=float, help="initial learning rate")
402
+ parser.add_argument("--momentum", default=0.9, type=float, metavar="M", help="momentum")
403
+ parser.add_argument(
404
+ "--wd",
405
+ "--weight-decay",
406
+ default=1e-4,
407
+ type=float,
408
+ metavar="W",
409
+ help="weight decay (default: 1e-4)",
410
+ dest="weight_decay",
411
+ )
412
+ parser.add_argument(
413
+ "--norm-weight-decay",
414
+ default=None,
415
+ type=float,
416
+ help="weight decay for Normalization layers (default: None, same value as --wd)",
417
+ )
418
+ parser.add_argument(
419
+ "--bias-weight-decay",
420
+ default=None,
421
+ type=float,
422
+ help="weight decay for bias parameters of all layers (default: None, same value as --wd)",
423
+ )
424
+ parser.add_argument(
425
+ "--transformer-embedding-decay",
426
+ default=None,
427
+ type=float,
428
+ help="weight decay for embedding parameters for vision transformer models (default: None, same value as --wd)",
429
+ )
430
+ parser.add_argument(
431
+ "--label-smoothing", default=0.0, type=float, help="label smoothing (default: 0.0)", dest="label_smoothing"
432
+ )
433
+ parser.add_argument("--mixup-alpha", default=0.0, type=float, help="mixup alpha (default: 0.0)")
434
+ parser.add_argument("--cutmix-alpha", default=0.0, type=float, help="cutmix alpha (default: 0.0)")
435
+ parser.add_argument("--lr-scheduler", default="steplr", type=str, help="the lr scheduler (default: steplr)")
436
+ parser.add_argument("--lr-warmup-epochs", default=0, type=int, help="the number of epochs to warmup (default: 0)")
437
+ parser.add_argument(
438
+ "--lr-warmup-method", default="constant", type=str, help="the warmup method (default: constant)"
439
+ )
440
+ parser.add_argument("--lr-warmup-decay", default=0.01, type=float, help="the decay for lr")
441
+ parser.add_argument("--lr-step-size", default=30, type=int, help="decrease lr every step-size epochs")
442
+ parser.add_argument("--lr-gamma", default=0.1, type=float, help="decrease lr by a factor of lr-gamma")
443
+ parser.add_argument("--lr-min", default=0.0, type=float, help="minimum lr of lr schedule (default: 0.0)")
444
+ parser.add_argument("--print-freq", default=10, type=int, help="print frequency")
445
+ parser.add_argument("--output-dir", default=".", type=str, help="path to save outputs")
446
+ parser.add_argument("--resume", default="", type=str, help="path of checkpoint")
447
+ parser.add_argument("--start-epoch", default=0, type=int, metavar="N", help="start epoch")
448
+ parser.add_argument(
449
+ "--cache-dataset",
450
+ dest="cache_dataset",
451
+ help="Cache the datasets for quicker initialization. It also serializes the transforms",
452
+ action="store_true",
453
+ )
454
+ parser.add_argument(
455
+ "--sync-bn",
456
+ dest="sync_bn",
457
+ help="Use sync batch norm",
458
+ action="store_true",
459
+ )
460
+ parser.add_argument(
461
+ "--test-only",
462
+ dest="test_only",
463
+ help="Only test the model",
464
+ action="store_true",
465
+ )
466
+ parser.add_argument("--auto-augment", default=None, type=str, help="auto augment policy (default: None)")
467
+ parser.add_argument("--ra-magnitude", default=9, type=int, help="magnitude of auto augment policy")
468
+ parser.add_argument("--augmix-severity", default=3, type=int, help="severity of augmix policy")
469
+ parser.add_argument("--random-erase", default=0.0, type=float, help="random erasing probability (default: 0.0)")
470
+
471
+ # Mixed precision training parameters
472
+ parser.add_argument("--amp", action="store_true", help="Use torch.cuda.amp for mixed precision training")
473
+
474
+ # distributed training parameters
475
+ parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes")
476
+ parser.add_argument("--dist-url", default="env://", type=str, help="url used to set up distributed training")
477
+ parser.add_argument(
478
+ "--model-ema", action="store_true", help="enable tracking Exponential Moving Average of model parameters"
479
+ )
480
+ parser.add_argument(
481
+ "--model-ema-steps",
482
+ type=int,
483
+ default=32,
484
+ help="the number of iterations that controls how often to update the EMA model (default: 32)",
485
+ )
486
+ parser.add_argument(
487
+ "--model-ema-decay",
488
+ type=float,
489
+ default=0.99998,
490
+ help="decay factor for Exponential Moving Average of model parameters (default: 0.99998)",
491
+ )
492
+ parser.add_argument(
493
+ "--use-deterministic-algorithms", action="store_true", help="Forces the use of deterministic algorithms only."
494
+ )
495
+ parser.add_argument(
496
+ "--interpolation", default="bilinear", type=str, help="the interpolation method (default: bilinear)"
497
+ )
498
+ parser.add_argument(
499
+ "--val-resize-size", default=256, type=int, help="the resize size used for validation (default: 256)"
500
+ )
501
+ parser.add_argument(
502
+ "--val-crop-size", default=224, type=int, help="the central crop size used for validation (default: 224)"
503
+ )
504
+ parser.add_argument(
505
+ "--train-crop-size", default=224, type=int, help="the random crop size used for training (default: 224)"
506
+ )
507
+ parser.add_argument("--clip-grad-norm", default=None, type=float, help="the maximum gradient norm (default None)")
508
+ parser.add_argument("--ra-sampler", action="store_true", help="whether to use Repeated Augmentation in training")
509
+ parser.add_argument(
510
+ "--ra-reps", default=3, type=int, help="number of repetitions for Repeated Augmentation (default: 3)"
511
+ )
512
+ parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")
513
+
514
+ parser.add_argument("--apply-trp", action="store_true", help="enable applying trp")
515
+ parser.add_argument("--trp-depths", type=int, help="trp depth")
516
+ parser.add_argument("--trp-p", type=float, help="trp p")
517
+ parser.add_argument("--trp-lambdas", nargs="+", type=float, help="trp lambdas")
518
+
519
+ return parser
520
+
521
+
522
+ if __name__ == "__main__":
523
+ args = get_args_parser().parse_args()
524
+ main(args)
hpo-examples/image-classification/train_quantization.py ADDED
@@ -0,0 +1,265 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import datetime
3
+ import os
4
+ import time
5
+
6
+ import torch
7
+ import torch.ao.quantization
8
+ import torch.utils.data
9
+ import torchvision
10
+ import utils
11
+ from torch import nn
12
+ from train import evaluate, load_data, train_one_epoch
13
+
14
+
15
+ def main(args):
16
+ if args.output_dir:
17
+ utils.mkdir(args.output_dir)
18
+
19
+ utils.init_distributed_mode(args)
20
+ print(args)
21
+
22
+ if args.post_training_quantize and args.distributed:
23
+ raise RuntimeError("Post training quantization example should not be performed on distributed mode")
24
+
25
+ # Set backend engine to ensure that quantized model runs on the correct kernels
26
+ if args.backend not in torch.backends.quantized.supported_engines:
27
+ raise RuntimeError("Quantized backend not supported: " + str(args.backend))
28
+ torch.backends.quantized.engine = args.backend
29
+
30
+ device = torch.device(args.device)
31
+ torch.backends.cudnn.benchmark = True
32
+
33
+ # Data loading code
34
+ print("Loading data")
35
+ train_dir = os.path.join(args.data_path, "train")
36
+ val_dir = os.path.join(args.data_path, "val")
37
+
38
+ dataset, dataset_test, train_sampler, test_sampler = load_data(train_dir, val_dir, args)
39
+ data_loader = torch.utils.data.DataLoader(
40
+ dataset, batch_size=args.batch_size, sampler=train_sampler, num_workers=args.workers, pin_memory=True
41
+ )
42
+
43
+ data_loader_test = torch.utils.data.DataLoader(
44
+ dataset_test, batch_size=args.eval_batch_size, sampler=test_sampler, num_workers=args.workers, pin_memory=True
45
+ )
46
+
47
+ print("Creating model", args.model)
48
+ # when training quantized models, we always start from a pre-trained fp32 reference model
49
+ prefix = "quantized_"
50
+ model_name = args.model
51
+ if not model_name.startswith(prefix):
52
+ model_name = prefix + model_name
53
+ model = torchvision.models.get_model(model_name, weights=args.weights, quantize=args.test_only)
54
+ model.to(device)
55
+
56
+ if not (args.test_only or args.post_training_quantize):
57
+ model.fuse_model(is_qat=True)
58
+ model.qconfig = torch.ao.quantization.get_default_qat_qconfig(args.backend)
59
+ torch.ao.quantization.prepare_qat(model, inplace=True)
60
+
61
+ if args.distributed and args.sync_bn:
62
+ model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
63
+
64
+ optimizer = torch.optim.SGD(
65
+ model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay
66
+ )
67
+
68
+ lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma)
69
+
70
+ criterion = nn.CrossEntropyLoss()
71
+ model_without_ddp = model
72
+ if args.distributed:
73
+ model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
74
+ model_without_ddp = model.module
75
+
76
+ if args.resume:
77
+ checkpoint = torch.load(args.resume, map_location="cpu")
78
+ model_without_ddp.load_state_dict(checkpoint["model"])
79
+ optimizer.load_state_dict(checkpoint["optimizer"])
80
+ lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
81
+ args.start_epoch = checkpoint["epoch"] + 1
82
+
83
+ if args.post_training_quantize:
84
+ # perform calibration on a subset of the training dataset
85
+ # for that, create a subset of the training dataset
86
+ ds = torch.utils.data.Subset(dataset, indices=list(range(args.batch_size * args.num_calibration_batches)))
87
+ data_loader_calibration = torch.utils.data.DataLoader(
88
+ ds, batch_size=args.batch_size, shuffle=False, num_workers=args.workers, pin_memory=True
89
+ )
90
+ model.eval()
91
+ model.fuse_model(is_qat=False)
92
+ model.qconfig = torch.ao.quantization.get_default_qconfig(args.backend)
93
+ torch.ao.quantization.prepare(model, inplace=True)
94
+ # Calibrate first
95
+ print("Calibrating")
96
+ evaluate(model, criterion, data_loader_calibration, device=device, print_freq=1)
97
+ torch.ao.quantization.convert(model, inplace=True)
98
+ if args.output_dir:
99
+ print("Saving quantized model")
100
+ if utils.is_main_process():
101
+ torch.save(model.state_dict(), os.path.join(args.output_dir, "quantized_post_train_model.pth"))
102
+ print("Evaluating post-training quantized model")
103
+ evaluate(model, criterion, data_loader_test, device=device)
104
+ return
105
+
106
+ if args.test_only:
107
+ evaluate(model, criterion, data_loader_test, device=device)
108
+ return
109
+
110
+ model.apply(torch.ao.quantization.enable_observer)
111
+ model.apply(torch.ao.quantization.enable_fake_quant)
112
+ start_time = time.time()
113
+ for epoch in range(args.start_epoch, args.epochs):
114
+ if args.distributed:
115
+ train_sampler.set_epoch(epoch)
116
+ print("Starting training for epoch", epoch)
117
+ train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args)
118
+ lr_scheduler.step()
119
+ with torch.inference_mode():
120
+ if epoch >= args.num_observer_update_epochs:
121
+ print("Disabling observer for subseq epochs, epoch = ", epoch)
122
+ model.apply(torch.ao.quantization.disable_observer)
123
+ if epoch >= args.num_batch_norm_update_epochs:
124
+ print("Freezing BN for subseq epochs, epoch = ", epoch)
125
+ model.apply(torch.nn.intrinsic.qat.freeze_bn_stats)
126
+ print("Evaluate QAT model")
127
+
128
+ evaluate(model, criterion, data_loader_test, device=device, log_suffix="QAT")
129
+ quantized_eval_model = copy.deepcopy(model_without_ddp)
130
+ quantized_eval_model.eval()
131
+ quantized_eval_model.to(torch.device("cpu"))
132
+ torch.ao.quantization.convert(quantized_eval_model, inplace=True)
133
+
134
+ print("Evaluate Quantized model")
135
+ evaluate(quantized_eval_model, criterion, data_loader_test, device=torch.device("cpu"))
136
+
137
+ model.train()
138
+
139
+ if args.output_dir:
140
+ checkpoint = {
141
+ "model": model_without_ddp.state_dict(),
142
+ "eval_model": quantized_eval_model.state_dict(),
143
+ "optimizer": optimizer.state_dict(),
144
+ "lr_scheduler": lr_scheduler.state_dict(),
145
+ "epoch": epoch,
146
+ "args": args,
147
+ }
148
+ utils.save_on_master(checkpoint, os.path.join(args.output_dir, f"model_{epoch}.pth"))
149
+ utils.save_on_master(checkpoint, os.path.join(args.output_dir, "checkpoint.pth"))
150
+ print("Saving models after epoch ", epoch)
151
+
152
+ total_time = time.time() - start_time
153
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
154
+ print(f"Training time {total_time_str}")
155
+
156
+
157
+ def get_args_parser(add_help=True):
158
+ import argparse
159
+
160
+ parser = argparse.ArgumentParser(description="PyTorch Quantized Classification Training", add_help=add_help)
161
+
162
+ parser.add_argument("--data-path", default="/datasets01/imagenet_full_size/061417/", type=str, help="dataset path")
163
+ parser.add_argument("--model", default="mobilenet_v2", type=str, help="model name")
164
+ parser.add_argument("--backend", default="qnnpack", type=str, help="fbgemm or qnnpack")
165
+ parser.add_argument("--device", default="cuda", type=str, help="device (Use cuda or cpu Default: cuda)")
166
+
167
+ parser.add_argument(
168
+ "-b", "--batch-size", default=32, type=int, help="images per gpu, the total batch size is $NGPU x batch_size"
169
+ )
170
+ parser.add_argument("--eval-batch-size", default=128, type=int, help="batch size for evaluation")
171
+ parser.add_argument("--epochs", default=90, type=int, metavar="N", help="number of total epochs to run")
172
+ parser.add_argument(
173
+ "--num-observer-update-epochs",
174
+ default=4,
175
+ type=int,
176
+ metavar="N",
177
+ help="number of total epochs to update observers",
178
+ )
179
+ parser.add_argument(
180
+ "--num-batch-norm-update-epochs",
181
+ default=3,
182
+ type=int,
183
+ metavar="N",
184
+ help="number of total epochs to update batch norm stats",
185
+ )
186
+ parser.add_argument(
187
+ "--num-calibration-batches",
188
+ default=32,
189
+ type=int,
190
+ metavar="N",
191
+ help="number of batches of training set for \
192
+ observer calibration ",
193
+ )
194
+
195
+ parser.add_argument(
196
+ "-j", "--workers", default=16, type=int, metavar="N", help="number of data loading workers (default: 16)"
197
+ )
198
+ parser.add_argument("--lr", default=0.0001, type=float, help="initial learning rate")
199
+ parser.add_argument("--momentum", default=0.9, type=float, metavar="M", help="momentum")
200
+ parser.add_argument(
201
+ "--wd",
202
+ "--weight-decay",
203
+ default=1e-4,
204
+ type=float,
205
+ metavar="W",
206
+ help="weight decay (default: 1e-4)",
207
+ dest="weight_decay",
208
+ )
209
+ parser.add_argument("--lr-step-size", default=30, type=int, help="decrease lr every step-size epochs")
210
+ parser.add_argument("--lr-gamma", default=0.1, type=float, help="decrease lr by a factor of lr-gamma")
211
+ parser.add_argument("--print-freq", default=10, type=int, help="print frequency")
212
+ parser.add_argument("--output-dir", default=".", type=str, help="path to save outputs")
213
+ parser.add_argument("--resume", default="", type=str, help="path of checkpoint")
214
+ parser.add_argument("--start-epoch", default=0, type=int, metavar="N", help="start epoch")
215
+ parser.add_argument(
216
+ "--cache-dataset",
217
+ dest="cache_dataset",
218
+ help="Cache the datasets for quicker initialization. \
219
+ It also serializes the transforms",
220
+ action="store_true",
221
+ )
222
+ parser.add_argument(
223
+ "--sync-bn",
224
+ dest="sync_bn",
225
+ help="Use sync batch norm",
226
+ action="store_true",
227
+ )
228
+ parser.add_argument(
229
+ "--test-only",
230
+ dest="test_only",
231
+ help="Only test the model",
232
+ action="store_true",
233
+ )
234
+ parser.add_argument(
235
+ "--post-training-quantize",
236
+ dest="post_training_quantize",
237
+ help="Post training quantize the model",
238
+ action="store_true",
239
+ )
240
+
241
+ # distributed training parameters
242
+ parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes")
243
+ parser.add_argument("--dist-url", default="env://", type=str, help="url used to set up distributed training")
244
+
245
+ parser.add_argument(
246
+ "--interpolation", default="bilinear", type=str, help="the interpolation method (default: bilinear)"
247
+ )
248
+ parser.add_argument(
249
+ "--val-resize-size", default=256, type=int, help="the resize size used for validation (default: 256)"
250
+ )
251
+ parser.add_argument(
252
+ "--val-crop-size", default=224, type=int, help="the central crop size used for validation (default: 224)"
253
+ )
254
+ parser.add_argument(
255
+ "--train-crop-size", default=224, type=int, help="the random crop size used for training (default: 224)"
256
+ )
257
+ parser.add_argument("--clip-grad-norm", default=None, type=float, help="the maximum gradient norm (default None)")
258
+ parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")
259
+
260
+ return parser
261
+
262
+
263
+ if __name__ == "__main__":
264
+ args = get_args_parser().parse_args()
265
+ main(args)
hpo-examples/image-classification/transforms.py ADDED
@@ -0,0 +1,183 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from typing import Tuple
3
+
4
+ import torch
5
+ from torch import Tensor
6
+ from torchvision.transforms import functional as F
7
+
8
+
9
+ class RandomMixup(torch.nn.Module):
10
+ """Randomly apply Mixup to the provided batch and targets.
11
+ The class implements the data augmentations as described in the paper
12
+ `"mixup: Beyond Empirical Risk Minimization" <https://arxiv.org/abs/1710.09412>`_.
13
+
14
+ Args:
15
+ num_classes (int): number of classes used for one-hot encoding.
16
+ p (float): probability of the batch being transformed. Default value is 0.5.
17
+ alpha (float): hyperparameter of the Beta distribution used for mixup.
18
+ Default value is 1.0.
19
+ inplace (bool): boolean to make this transform inplace. Default set to False.
20
+ """
21
+
22
+ def __init__(self, num_classes: int, p: float = 0.5, alpha: float = 1.0, inplace: bool = False) -> None:
23
+ super().__init__()
24
+
25
+ if num_classes < 1:
26
+ raise ValueError(
27
+ f"Please provide a valid positive value for the num_classes. Got num_classes={num_classes}"
28
+ )
29
+
30
+ if alpha <= 0:
31
+ raise ValueError("Alpha param can't be zero.")
32
+
33
+ self.num_classes = num_classes
34
+ self.p = p
35
+ self.alpha = alpha
36
+ self.inplace = inplace
37
+
38
+ def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
39
+ """
40
+ Args:
41
+ batch (Tensor): Float tensor of size (B, C, H, W)
42
+ target (Tensor): Integer tensor of size (B, )
43
+
44
+ Returns:
45
+ Tensor: Randomly transformed batch.
46
+ """
47
+ if batch.ndim != 4:
48
+ raise ValueError(f"Batch ndim should be 4. Got {batch.ndim}")
49
+ if target.ndim != 1:
50
+ raise ValueError(f"Target ndim should be 1. Got {target.ndim}")
51
+ if not batch.is_floating_point():
52
+ raise TypeError(f"Batch dtype should be a float tensor. Got {batch.dtype}.")
53
+ if target.dtype != torch.int64:
54
+ raise TypeError(f"Target dtype should be torch.int64. Got {target.dtype}")
55
+
56
+ if not self.inplace:
57
+ batch = batch.clone()
58
+ target = target.clone()
59
+
60
+ if target.ndim == 1:
61
+ target = torch.nn.functional.one_hot(target, num_classes=self.num_classes).to(dtype=batch.dtype)
62
+
63
+ if torch.rand(1).item() >= self.p:
64
+ return batch, target
65
+
66
+ # It's faster to roll the batch by one instead of shuffling it to create image pairs
67
+ batch_rolled = batch.roll(1, 0)
68
+ target_rolled = target.roll(1, 0)
69
+
70
+ # Implemented as on mixup paper, page 3.
71
+ lambda_param = float(torch._sample_dirichlet(torch.tensor([self.alpha, self.alpha]))[0])
72
+ batch_rolled.mul_(1.0 - lambda_param)
73
+ batch.mul_(lambda_param).add_(batch_rolled)
74
+
75
+ target_rolled.mul_(1.0 - lambda_param)
76
+ target.mul_(lambda_param).add_(target_rolled)
77
+
78
+ return batch, target
79
+
80
+ def __repr__(self) -> str:
81
+ s = (
82
+ f"{self.__class__.__name__}("
83
+ f"num_classes={self.num_classes}"
84
+ f", p={self.p}"
85
+ f", alpha={self.alpha}"
86
+ f", inplace={self.inplace}"
87
+ f")"
88
+ )
89
+ return s
90
+
91
+
92
+ class RandomCutmix(torch.nn.Module):
93
+ """Randomly apply Cutmix to the provided batch and targets.
94
+ The class implements the data augmentations as described in the paper
95
+ `"CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features"
96
+ <https://arxiv.org/abs/1905.04899>`_.
97
+
98
+ Args:
99
+ num_classes (int): number of classes used for one-hot encoding.
100
+ p (float): probability of the batch being transformed. Default value is 0.5.
101
+ alpha (float): hyperparameter of the Beta distribution used for cutmix.
102
+ Default value is 1.0.
103
+ inplace (bool): boolean to make this transform inplace. Default set to False.
104
+ """
105
+
106
+ def __init__(self, num_classes: int, p: float = 0.5, alpha: float = 1.0, inplace: bool = False) -> None:
107
+ super().__init__()
108
+ if num_classes < 1:
109
+ raise ValueError("Please provide a valid positive value for the num_classes.")
110
+ if alpha <= 0:
111
+ raise ValueError("Alpha param can't be zero.")
112
+
113
+ self.num_classes = num_classes
114
+ self.p = p
115
+ self.alpha = alpha
116
+ self.inplace = inplace
117
+
118
+ def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
119
+ """
120
+ Args:
121
+ batch (Tensor): Float tensor of size (B, C, H, W)
122
+ target (Tensor): Integer tensor of size (B, )
123
+
124
+ Returns:
125
+ Tensor: Randomly transformed batch.
126
+ """
127
+ if batch.ndim != 4:
128
+ raise ValueError(f"Batch ndim should be 4. Got {batch.ndim}")
129
+ if target.ndim != 1:
130
+ raise ValueError(f"Target ndim should be 1. Got {target.ndim}")
131
+ if not batch.is_floating_point():
132
+ raise TypeError(f"Batch dtype should be a float tensor. Got {batch.dtype}.")
133
+ if target.dtype != torch.int64:
134
+ raise TypeError(f"Target dtype should be torch.int64. Got {target.dtype}")
135
+
136
+ if not self.inplace:
137
+ batch = batch.clone()
138
+ target = target.clone()
139
+
140
+ if target.ndim == 1:
141
+ target = torch.nn.functional.one_hot(target, num_classes=self.num_classes).to(dtype=batch.dtype)
142
+
143
+ if torch.rand(1).item() >= self.p:
144
+ return batch, target
145
+
146
+ # It's faster to roll the batch by one instead of shuffling it to create image pairs
147
+ batch_rolled = batch.roll(1, 0)
148
+ target_rolled = target.roll(1, 0)
149
+
150
+ # Implemented as on cutmix paper, page 12 (with minor corrections on typos).
151
+ lambda_param = float(torch._sample_dirichlet(torch.tensor([self.alpha, self.alpha]))[0])
152
+ _, H, W = F.get_dimensions(batch)
153
+
154
+ r_x = torch.randint(W, (1,))
155
+ r_y = torch.randint(H, (1,))
156
+
157
+ r = 0.5 * math.sqrt(1.0 - lambda_param)
158
+ r_w_half = int(r * W)
159
+ r_h_half = int(r * H)
160
+
161
+ x1 = int(torch.clamp(r_x - r_w_half, min=0))
162
+ y1 = int(torch.clamp(r_y - r_h_half, min=0))
163
+ x2 = int(torch.clamp(r_x + r_w_half, max=W))
164
+ y2 = int(torch.clamp(r_y + r_h_half, max=H))
165
+
166
+ batch[:, :, y1:y2, x1:x2] = batch_rolled[:, :, y1:y2, x1:x2]
167
+ lambda_param = float(1.0 - (x2 - x1) * (y2 - y1) / (W * H))
168
+
169
+ target_rolled.mul_(1.0 - lambda_param)
170
+ target.mul_(lambda_param).add_(target_rolled)
171
+
172
+ return batch, target
173
+
174
+ def __repr__(self) -> str:
175
+ s = (
176
+ f"{self.__class__.__name__}("
177
+ f"num_classes={self.num_classes}"
178
+ f", p={self.p}"
179
+ f", alpha={self.alpha}"
180
+ f", inplace={self.inplace}"
181
+ f")"
182
+ )
183
+ return s
hpo-examples/image-classification/trplib.py ADDED
@@ -0,0 +1,1181 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn, Tensor
3
+ from torch.nn import functional as F
4
+
5
+ from torchvision.models.mobilenetv2 import MobileNetV2
6
+ from torchvision.models.resnet import ResNet
7
+ from torchvision.models.efficientnet import EfficientNet
8
+ from torchvision.models.vision_transformer import VisionTransformer
9
+ from torchvision.models.segmentation.fcn import FCN
10
+ from torchvision.models.segmentation.deeplabv3 import DeepLabV3
11
+
12
+ import transformers
13
+ from transformers.modeling_outputs import SequenceClassifierOutput, QuestionAnsweringModelOutput, CausalLMOutput, Seq2SeqLMOutput
14
+
15
+ from typing import Optional, Tuple, List, Union, Callable
16
+ from collections import OrderedDict
17
+ import types
18
+
19
+
20
+ def trp_criterion(trp_blocks: nn.ModuleList, shared_head: Callable, criterion: Callable, lambdas: List[float], hidden_states: Tensor, logits: Tensor, targets: Tensor, loss_normalization=False):
21
+ loss, mask = criterion(logits, targets)
22
+ if loss_normalization:
23
+ coeff = loss.detach()
24
+
25
+ embeds = [hidden_states]
26
+ predictions = []
27
+ for k, c in enumerate(lambdas):
28
+ embeds.append(trp_blocks[k](embeds[-1]))
29
+ predictions.append(shared_head(embeds[-1]))
30
+ replica_loss, mask = criterion(predictions[-1], targets, mask)
31
+ loss += c * replica_loss
32
+
33
+ if loss_normalization:
34
+ with torch.no_grad():
35
+ coeff = torch.exp(coeff) / torch.exp(loss.detach())
36
+ loss = coeff * loss
37
+
38
+ return loss
39
+
40
+
41
+ class TPBlock(nn.Module):
42
+ def __init__(self, depths: int, in_features: int, p: float, dim=-1):
43
+ super(TPBlock, self).__init__()
44
+
45
+ self.dropout = nn.Dropout(p)
46
+
47
+ self.cdim = dim
48
+
49
+ blocks = []
50
+ for _ in range(depths):
51
+ blocks.append(nn.Linear(in_features, in_features))
52
+ nn.init.constant_(blocks[-1].weight, 0.0)
53
+ nn.init.constant_(blocks[-1].bias, 0.0)
54
+ blocks.append(nn.ReLU())
55
+ self.blocks = nn.Sequential(*blocks)
56
+
57
+ def forward(self, x):
58
+ x = self.dropout(x)
59
+ if self.cdim == -1:
60
+ x = x + self.blocks(x)
61
+ else:
62
+ x = x + torch.movedim(self.blocks(torch.movedim(x, self.cdim, -1)), -1, self.cdim)
63
+ return x
64
+
65
+
66
+ class Config:
67
+ @staticmethod
68
+ def gen_criterion(*args, **kwargs):
69
+ def func(input, target, mask=None):
70
+ """
71
+ Args:
72
+ input (Tensor): Input tensor.
73
+ target (Tensor): Target labels.
74
+
75
+ Returns:
76
+ loss (Tensor): Scalar tensor representing the loss.
77
+ mask (Tensor): Boolean mask tensor with the same shape of target.
78
+ """
79
+ pass
80
+ return func
81
+
82
+ @staticmethod
83
+ def gen_shared_head(*args, **kwargs):
84
+ def func(hidden_states):
85
+ """
86
+ Args:
87
+ hidden_states (Tensor): Hidden States tensor.
88
+
89
+ Returns:
90
+ logits (Tensor): Logits tensor.
91
+ """
92
+ pass
93
+ return func
94
+
95
+ @staticmethod
96
+ def forward(*args, **kwargs):
97
+ pass
98
+
99
+
100
+ # Wav2Vec2 for Audio Classification
101
+ class Wav2Vec2ForSequenceClassificationConfig(Config):
102
+ _HIDDEN_STATES_START_POSITION = 2
103
+
104
+ @staticmethod
105
+ def gen_criterion():
106
+ def func(input, target, mask=None):
107
+ """
108
+ Args:
109
+ input (Tensor): Input tensor of shape [B, C].
110
+ target (Tensor): Target labels of shape [B].
111
+
112
+ Returns:
113
+ loss (Tensor): Scalar tensor representing the loss.
114
+ mask (Tensor): Boolean mask tensor of shape [B].
115
+ """
116
+ if mask is None:
117
+ mask = torch.ones_like(target, dtype=torch.float32, device=target.device)
118
+
119
+ unmasked_loss = F.cross_entropy(input, target, reduction="none")
120
+ loss = torch.sum(mask * unmasked_loss) / (torch.sum(mask) + 1e-6)
121
+
122
+ with torch.no_grad():
123
+ mask = mask * torch.eq(torch.argmax(input, dim=1), target).to(input.dtype)
124
+
125
+ return loss, mask
126
+ return func
127
+
128
+ @staticmethod
129
+ def gen_shared_head(self, attention_mask):
130
+ def func(hidden_states):
131
+ """
132
+ Args:
133
+ hidden_states (Tensor): Hidden States of shape [B, L, hidden_units].
134
+
135
+ Returns:
136
+ logits (Tensor): Logits tensor of shape [B, C].
137
+ """
138
+ _hidden_states = self.projector(hidden_states)
139
+ if attention_mask is None:
140
+ pooled_output = _hidden_states.mean(dim=1)
141
+ else:
142
+ padding_mask = self._get_feature_vector_attention_mask(_hidden_states.shape[1], attention_mask)
143
+ expand_padding_mask = padding_mask.unsqueeze(-1).repeat(1, 1, _hidden_states.shape[2])
144
+ _hidden_states[~expand_padding_mask] = 0.0
145
+ pooled_output = _hidden_states.sum(dim=1) / padding_mask.sum(dim=1).view(-1, 1)
146
+
147
+ logits = self.classifier(pooled_output)
148
+ return logits
149
+ return func
150
+
151
+ @staticmethod
152
+ def gen_forward(lambdas, loss_normalization=False):
153
+ def func(
154
+ self,
155
+ input_values: Optional[torch.Tensor],
156
+ attention_mask: Optional[torch.Tensor] = None,
157
+ output_attentions: Optional[bool] = None,
158
+ output_hidden_states: Optional[bool] = None,
159
+ return_dict: Optional[bool] = None,
160
+ labels: Optional[torch.Tensor] = None,
161
+ ) -> Union[Tuple, SequenceClassifierOutput]:
162
+ r"""
163
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
164
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
165
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
166
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
167
+ """
168
+
169
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
170
+ output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states
171
+
172
+ outputs = self.wav2vec2(
173
+ input_values,
174
+ attention_mask=attention_mask,
175
+ output_attentions=output_attentions,
176
+ output_hidden_states=output_hidden_states,
177
+ return_dict=return_dict,
178
+ )
179
+
180
+ if self.config.use_weighted_layer_sum:
181
+ hidden_states = outputs[Wav2Vec2ForSequenceClassificationConfig._HIDDEN_STATES_START_POSITION]
182
+ hidden_states = torch.stack(hidden_states, dim=1)
183
+ norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
184
+ hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
185
+ else:
186
+ hidden_states = outputs[0]
187
+
188
+ _hidden_states = self.projector(hidden_states)
189
+ if attention_mask is None:
190
+ pooled_output = _hidden_states.mean(dim=1)
191
+ else:
192
+ padding_mask = self._get_feature_vector_attention_mask(_hidden_states.shape[1], attention_mask)
193
+ expand_padding_mask = padding_mask.unsqueeze(-1).repeat(1, 1, _hidden_states.shape[2])
194
+ _hidden_states[~expand_padding_mask] = 0.0
195
+ pooled_output = _hidden_states.sum(dim=1) / padding_mask.sum(dim=1).view(-1, 1)
196
+
197
+ logits = self.classifier(pooled_output)
198
+
199
+ loss = None
200
+ if labels is not None:
201
+ shared_head = Wav2Vec2ForSequenceClassificationConfig.gen_shared_head(self, attention_mask)
202
+ criterion = Wav2Vec2ForSequenceClassificationConfig.gen_criterion()
203
+ loss = trp_criterion(self.trp_blocks, shared_head, criterion, lambdas, hidden_states, logits.view(-1, self.config.num_labels), labels.view(-1), loss_normalization) # NOTE: Apply TRP!
204
+
205
+ if not return_dict:
206
+ output = (logits,) + outputs[Wav2Vec2ForSequenceClassificationConfig._HIDDEN_STATES_START_POSITION:]
207
+ return ((loss,) + output) if loss is not None else output
208
+
209
+ return SequenceClassifierOutput(
210
+ loss=loss,
211
+ logits=logits,
212
+ hidden_states=outputs.hidden_states,
213
+ attentions=outputs.attentions,
214
+ )
215
+ return func
216
+
217
+
218
+ # MobileNetV2 for Image Classification
219
+ class MobileNetV2Config(Config):
220
+ @staticmethod
221
+ def gen_criterion(label_smoothing=0.0, top_k=1):
222
+ def func(input, target, mask=None):
223
+ """
224
+ Args:
225
+ input (Tensor): Input tensor of shape [B, C].
226
+ target (Tensor): Target labels of shape [B] or [B, C].
227
+
228
+ Returns:
229
+ loss (Tensor): Scalar tensor representing the loss.
230
+ mask (Tensor): Boolean mask tensor of shape [B].
231
+ """
232
+ label = torch.argmax(target, dim=1) if label_smoothing > 0.0 else target
233
+
234
+ unmasked_loss = F.cross_entropy(input, label, reduction="none", label_smoothing=label_smoothing)
235
+ if mask is None:
236
+ mask = torch.ones_like(unmasked_loss, dtype=torch.float32, device=target.device)
237
+ loss = torch.sum(mask * unmasked_loss) / (torch.sum(mask) + 1e-6)
238
+
239
+ with torch.no_grad():
240
+ topk_values, topk_indices = torch.topk(input, top_k, dim=-1)
241
+ mask = mask * torch.eq(topk_indices, label[:, None]).any(dim=-1).to(input.dtype)
242
+
243
+ return loss, mask
244
+ return func
245
+
246
+ @staticmethod
247
+ def gen_shared_head(self):
248
+ def func(x):
249
+ """
250
+ Args:
251
+ x (Tensor): Hidden States tensor of shape [B, hidden_units].
252
+
253
+ Returns:
254
+ logits (Tensor): Logits tensor of shape [B, C].
255
+ """
256
+ logits = self.classifier(x)
257
+ return logits
258
+ return func
259
+
260
+ @staticmethod
261
+ def gen_forward(lambdas, loss_normalization=True, label_smoothing=0.0, top_k=1):
262
+ def func(self, images: Tensor, targets=None):
263
+ x = self.features(images)
264
+ x = nn.functional.adaptive_avg_pool2d(x, (1, 1))
265
+ x = torch.flatten(x, 1)
266
+ logits = self.classifier(x)
267
+
268
+ if self.training:
269
+ torch._assert(targets is not None, "targets should not be none when in training mode")
270
+ shared_head = MobileNetV2Config.gen_shared_head(self)
271
+ criterion = MobileNetV2Config.gen_criterion(label_smoothing, top_k)
272
+ loss = trp_criterion(self.trp_blocks, shared_head, criterion, lambdas, x, logits, targets, loss_normalization)
273
+ return logits, loss
274
+ return logits
275
+ return func
276
+
277
+
278
+ # ResNet for Image Classification
279
+ class ResNetConfig(MobileNetV2Config):
280
+ @staticmethod
281
+ def gen_shared_head(self):
282
+ def func(x):
283
+ """
284
+ Args:
285
+ x (Tensor): Hidden States tensor of shape [B, hidden_units].
286
+
287
+ Returns:
288
+ logits (Tensor): Logits tensor of shape [B, C].
289
+ """
290
+ logits = self.fc(x)
291
+ return logits
292
+ return func
293
+
294
+ @staticmethod
295
+ def gen_forward(lambdas, loss_normalization=True, label_smoothing=0.0, top_k=1):
296
+ def func(self, images: Tensor, targets=None):
297
+ x = self.conv1(images)
298
+ x = self.bn1(x)
299
+ x = self.relu(x)
300
+ x = self.maxpool(x)
301
+
302
+ x = self.layer1(x)
303
+ x = self.layer2(x)
304
+ x = self.layer3(x)
305
+ x = self.layer4(x)
306
+
307
+ x = self.avgpool(x)
308
+ x = torch.flatten(x, 1)
309
+ logits = self.fc(x)
310
+
311
+ if self.training:
312
+ torch._assert(targets is not None, "targets should not be none when in training mode")
313
+ shared_head = ResNetConfig.gen_shared_head(self)
314
+ criterion = ResNetConfig.gen_criterion(label_smoothing, top_k)
315
+ loss = trp_criterion(self.trp_blocks, shared_head, criterion, lambdas, x, logits, targets, loss_normalization)
316
+ return logits, loss
317
+ return logits
318
+ return func
319
+
320
+
321
+ # EfficientNet for Image Classification
322
+ class EfficientNetConfig(MobileNetV2Config):
323
+ @staticmethod
324
+ def gen_shared_head(self):
325
+ def func(x):
326
+ """
327
+ Args:
328
+ x (Tensor): Hidden States tensor of shape [B, hidden_units].
329
+
330
+ Returns:
331
+ logits (Tensor): Logits tensor of shape [B, C].
332
+ """
333
+ logits = self.classifier(x)
334
+ return logits
335
+ return func
336
+
337
+ @staticmethod
338
+ def gen_forward(lambdas, loss_normalization=True, label_smoothing=0.0, top_k=1):
339
+ def func(self, images: Tensor, targets=None):
340
+ x = self.features(images)
341
+ x = self.avgpool(x)
342
+ x = torch.flatten(x, 1)
343
+ logits = self.classifier(x)
344
+
345
+ if self.training:
346
+ torch._assert(targets is not None, "targets should not be none when in training mode")
347
+ shared_head = EfficientNetConfig.gen_shared_head(self)
348
+ criterion = EfficientNetConfig.gen_criterion(label_smoothing, top_k)
349
+ loss = trp_criterion(self.trp_blocks, shared_head, criterion, lambdas, x, logits, targets, loss_normalization)
350
+ return logits, loss
351
+ return logits
352
+ return func
353
+
354
+
355
+ # ViT for Image Classification
356
+ class VisionTransformerConfig(MobileNetV2Config):
357
+ @staticmethod
358
+ def gen_shared_head(self):
359
+ def func(x):
360
+ """
361
+ Args:
362
+ x (Tensor): Hidden States tensor of shape [B, hidden_units].
363
+
364
+ Returns:
365
+ logits (Tensor): Logits tensor of shape [B, C].
366
+ """
367
+ logits = self.heads(x)
368
+ return logits
369
+ return func
370
+
371
+ @staticmethod
372
+ def gen_forward(lambdas, loss_normalization=True, label_smoothing=0.0, top_k=1):
373
+ def func(self, images: Tensor, targets=None):
374
+ x = self._process_input(images)
375
+ n = x.shape[0]
376
+ batch_class_token = self.class_token.expand(n, -1, -1)
377
+ x = torch.cat([batch_class_token, x], dim=1)
378
+ x = self.encoder(x)
379
+ x = x[:, 0]
380
+
381
+ logits = self.heads(x)
382
+
383
+ if self.training:
384
+ torch._assert(targets is not None, "targets should not be none when in training mode")
385
+ shared_head = VisionTransformerConfig.gen_shared_head(self)
386
+ criterion = VisionTransformerConfig.gen_criterion(label_smoothing, top_k)
387
+ loss = trp_criterion(self.trp_blocks, shared_head, criterion, lambdas, x, logits, targets, loss_normalization)
388
+ return logits, loss
389
+ return logits
390
+ return func
391
+
392
+
393
+ # Bert for Question Answering
394
+ class BertForQuestionAnsweringConfig(Config):
395
+ @staticmethod
396
+ def gen_criterion(top_k=1):
397
+ def func(input, target: List[Tensor], mask=None):
398
+ """
399
+ Args:
400
+ input (Tensor): Input tensor of shape [B, C, 2].
401
+ target (List[Tensor]):
402
+ Start Positions of shape [B].
403
+ End Positions of shape [B].
404
+
405
+ Returns:
406
+ loss (Tensor): Scalar tensor representing the loss.
407
+ mask (Tensor): Boolean mask tensor of shape [B].
408
+ """
409
+ start_positions, end_positions = target
410
+
411
+ if mask is None:
412
+ mask = torch.ones_like(start_positions, dtype=torch.float32, device=start_positions.device)
413
+
414
+ start_logits, end_logits = input.split(1, dim=-1)
415
+ start_logits = start_logits.squeeze(-1).contiguous()
416
+ end_logits = end_logits.squeeze(-1).contiguous()
417
+
418
+ # If we are on multi-GPU, split add a dimension
419
+ if len(start_positions.size()) > 1:
420
+ start_positions = start_positions.squeeze(-1)
421
+ if len(end_positions.size()) > 1:
422
+ end_positions = end_positions.squeeze(-1)
423
+ # sometimes the start/end positions are outside our model inputs, we ignore these terms
424
+ ignored_index = start_logits.size(1)
425
+ start_positions = start_positions.clamp(0, ignored_index)
426
+ end_positions = end_positions.clamp(0, ignored_index)
427
+
428
+ masked_start_losses = F.cross_entropy(start_logits, start_positions, ignore_index=ignored_index, reduction="none")
429
+ start_loss = torch.sum(mask * masked_start_losses) / (torch.sum(mask) + 1e-6)
430
+ masked_end_losses = F.cross_entropy(end_logits, end_positions, ignore_index=ignored_index, reduction="none")
431
+ end_loss = torch.sum(mask * masked_end_losses) / (torch.sum(mask) + 1e-6)
432
+
433
+ with torch.no_grad():
434
+ topk_values, topk_indices = torch.topk(start_logits, top_k, dim=1)
435
+ mask = mask * torch.eq(topk_indices, start_positions[:, None]).any(dim=1).to(start_logits.dtype)
436
+ topk_values, topk_indices = torch.topk(end_logits, top_k, dim=1)
437
+ mask = mask * torch.eq(topk_indices, end_positions[:, None]).any(dim=1).to(end_logits.dtype)
438
+
439
+ return (start_loss + end_loss) / 2, mask
440
+ return func
441
+
442
+ @staticmethod
443
+ def gen_shared_head(self):
444
+ def func(hidden_states):
445
+ """
446
+ Args:
447
+ hidden_states (Tensor): Hidden States of shape [B, C, hidden_units].
448
+
449
+ Returns:
450
+ logits (Tensor): Logits tensor of shape [B, C, 2].
451
+ """
452
+ logits = self.qa_outputs(hidden_states)
453
+ return logits
454
+ return func
455
+
456
+ @staticmethod
457
+ def gen_forward(lambdas, loss_normalization=True, top_k=1):
458
+ def func(
459
+ self,
460
+ input_ids: Optional[torch.Tensor] = None,
461
+ attention_mask: Optional[torch.Tensor] = None,
462
+ token_type_ids: Optional[torch.Tensor] = None,
463
+ position_ids: Optional[torch.Tensor] = None,
464
+ head_mask: Optional[torch.Tensor] = None,
465
+ inputs_embeds: Optional[torch.Tensor] = None,
466
+ start_positions: Optional[torch.Tensor] = None,
467
+ end_positions: Optional[torch.Tensor] = None,
468
+ output_attentions: Optional[bool] = None,
469
+ output_hidden_states: Optional[bool] = None,
470
+ return_dict: Optional[bool] = None,
471
+ ) -> Union[Tuple[torch.Tensor], QuestionAnsweringModelOutput]:
472
+ r"""
473
+ start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
474
+ Labels for position (index) of the start of the labelled span for computing the token classification loss.
475
+ Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
476
+ are not taken into account for computing the loss.
477
+ end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
478
+ Labels for position (index) of the end of the labelled span for computing the token classification loss.
479
+ Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
480
+ are not taken into account for computing the loss.
481
+ """
482
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
483
+
484
+ outputs = self.bert(
485
+ input_ids,
486
+ attention_mask=attention_mask,
487
+ token_type_ids=token_type_ids,
488
+ position_ids=position_ids,
489
+ head_mask=head_mask,
490
+ inputs_embeds=inputs_embeds,
491
+ output_attentions=output_attentions,
492
+ output_hidden_states=output_hidden_states,
493
+ return_dict=return_dict,
494
+ )
495
+
496
+ sequence_output = outputs[0]
497
+
498
+ logits = self.qa_outputs(sequence_output)
499
+ start_logits, end_logits = logits.split(1, dim=-1)
500
+ start_logits = start_logits.squeeze(-1).contiguous()
501
+ end_logits = end_logits.squeeze(-1).contiguous()
502
+
503
+ total_loss = None
504
+ if start_positions is not None and end_positions is not None:
505
+ shared_head = BertForQuestionAnsweringConfig.gen_shared_head(self)
506
+ criterion = BertForQuestionAnsweringConfig.gen_criterion()
507
+ total_loss = trp_criterion(self.trp_blocks, shared_head, criterion, lambdas, sequence_output, logits, [start_positions, end_positions], loss_normalization) # NOTE: Apply TRP!
508
+
509
+ if not return_dict:
510
+ output = (start_logits, end_logits) + outputs[2:]
511
+ return ((total_loss,) + output) if total_loss is not None else output
512
+
513
+ return QuestionAnsweringModelOutput(
514
+ loss=total_loss,
515
+ start_logits=start_logits,
516
+ end_logits=end_logits,
517
+ hidden_states=outputs.hidden_states,
518
+ attentions=outputs.attentions,
519
+ )
520
+ return func
521
+
522
+
523
+ # FCN for Semantic Segmentation
524
+ class FCNConfig(Config):
525
+ @staticmethod
526
+ def gen_criterion(top_k=1):
527
+ def func(input, target, mask=None):
528
+ """
529
+ Args:
530
+ input Tensor: input tensor of shape [B, C, H, W].
531
+ target (Tensor): Target labels of shape [B, H, W].
532
+
533
+ Returns:
534
+ loss (Tensor): Scalar tensor representing the loss.
535
+ mask (Tensor): Boolean mask tensor of shape [B, H, W].
536
+ """
537
+ if mask is None:
538
+ mask = torch.ones_like(target, dtype=torch.float32, device=target.device)
539
+
540
+ masked_loss = F.cross_entropy(input, target, ignore_index=255, reduction="none")
541
+ loss = torch.sum(mask * masked_loss) / (torch.sum(mask) + 1e-6)
542
+
543
+ with torch.no_grad():
544
+ topk_values, topk_indices = torch.topk(input, top_k, dim=1)
545
+ mask = mask * torch.eq(topk_indices, target[:, None, :, :]).any(dim=1).to(input.dtype)
546
+ # mask = mask * torch.eq(torch.argmax(x, dim=1), target).to(x.dtype)
547
+
548
+ return loss, mask
549
+ return func
550
+
551
+ @staticmethod
552
+ def gen_out_shared_head(self, input_shape):
553
+ def func(features):
554
+ """
555
+ Args:
556
+ features (Tensor): features tensor of shape [B, hidden_units, H, W].
557
+
558
+ Returns:
559
+ result (Tensors): result tensor of shape [B, C, H, W].
560
+ """
561
+ x = self.classifier(features)
562
+ result = F.interpolate(x, size=input_shape, mode="bilinear", align_corners=False)
563
+ return result
564
+ return func
565
+
566
+ @staticmethod
567
+ def gen_aux_shared_head(self, input_shape):
568
+ def func(features):
569
+ """
570
+ Args:
571
+ features (Tensor): features tensor of shape [B, hidden_units, H, W].
572
+
573
+ Returns:
574
+ result (Tensors): result tensor of shape [B, C, H, W].
575
+ """
576
+ x = self.aux_classifier(features)
577
+ result = F.interpolate(x, size=input_shape, mode="bilinear", align_corners=False)
578
+ return result
579
+ return func
580
+
581
+ @staticmethod
582
+ def gen_forward(lambdas, loss_normalization=True, top_k=1):
583
+ def func(self, images: Tensor, targets=None):
584
+ input_shape = images.shape[-2:]
585
+ # contract: features is a dict of tensors
586
+ features = self.backbone(images)
587
+
588
+ result = OrderedDict()
589
+ x = features["out"]
590
+ x = self.classifier(x)
591
+ x = F.interpolate(x, size=input_shape, mode="bilinear", align_corners=False)
592
+ result["out"] = x
593
+
594
+ if self.aux_classifier is not None:
595
+ x = features["aux"]
596
+ x = self.aux_classifier(x)
597
+ x = F.interpolate(x, size=input_shape, mode="bilinear", align_corners=False)
598
+ result["aux"] = x
599
+
600
+ if self.training:
601
+ torch._assert(targets is not None, "targets should not be none when in training mode")
602
+ out_shared_head = FCNConfig.gen_out_shared_head(self, input_shape)
603
+ aux_shared_head = FCNConfig.gen_aux_shared_head(self, input_shape)
604
+ criterion = FCNConfig.gen_criterion(top_k)
605
+ out_loss = trp_criterion(self.out_trp_blocks, out_shared_head, criterion, lambdas, features["out"], result["out"], targets, loss_normalization)
606
+ aux_loss = trp_criterion(self.aux_trp_blocks, aux_shared_head, criterion, lambdas, features["aux"], result["aux"], targets, loss_normalization)
607
+ loss = out_loss + 0.5 * aux_loss
608
+ return result, loss
609
+ return result
610
+ return func
611
+
612
+
613
+ # DeepLabV3Config for Semantic Segmentation
614
+ class DeepLabV3Config(FCNConfig):
615
+ pass
616
+
617
+
618
+ # Bert for Text Classification
619
+ class BertForSequenceClassificationConfig(Config):
620
+ @staticmethod
621
+ def gen_criterion():
622
+ def func(input, target, mask=None):
623
+ """
624
+ Args:
625
+ input (Tensor): Input tensor of shape [B, C].
626
+ target (Tensor): Target labels of shape [B].
627
+
628
+ Returns:
629
+ loss (Tensor): Scalar tensor representing the loss.
630
+ mask (Tensor): Boolean mask tensor of shape [B].
631
+ """
632
+ if mask is None:
633
+ mask = torch.ones_like(target, dtype=torch.float32, device=target.device)
634
+
635
+ unmasked_loss = F.cross_entropy(input, target, reduction="none")
636
+ loss = torch.sum(mask * unmasked_loss) / (torch.sum(mask) + 1e-6)
637
+
638
+ with torch.no_grad():
639
+ mask = mask * torch.eq(torch.argmax(input, dim=1), target).to(input.dtype)
640
+
641
+ return loss, mask
642
+ return func
643
+
644
+ @staticmethod
645
+ def gen_shared_head(self):
646
+ def func(hidden_states):
647
+ """
648
+ Args:
649
+ hidden_states (Tensor): Hidden States of shape [B, hidden_units].
650
+
651
+ Returns:
652
+ logits (Tensor): Logits tensor of shape [B, C].
653
+ """
654
+ logits = self.classifier(hidden_states)
655
+ return logits
656
+ return func
657
+
658
+ @staticmethod
659
+ def gen_forward(lambdas, loss_normalization=False):
660
+ def func(
661
+ self,
662
+ input_ids: Optional[torch.Tensor] = None,
663
+ attention_mask: Optional[torch.Tensor] = None,
664
+ token_type_ids: Optional[torch.Tensor] = None,
665
+ position_ids: Optional[torch.Tensor] = None,
666
+ head_mask: Optional[torch.Tensor] = None,
667
+ inputs_embeds: Optional[torch.Tensor] = None,
668
+ labels: Optional[torch.Tensor] = None,
669
+ output_attentions: Optional[bool] = None,
670
+ output_hidden_states: Optional[bool] = None,
671
+ return_dict: Optional[bool] = None,
672
+ ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
673
+ r"""
674
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
675
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
676
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
677
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
678
+ """
679
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
680
+
681
+ outputs = self.bert(
682
+ input_ids,
683
+ attention_mask=attention_mask,
684
+ token_type_ids=token_type_ids,
685
+ position_ids=position_ids,
686
+ head_mask=head_mask,
687
+ inputs_embeds=inputs_embeds,
688
+ output_attentions=output_attentions,
689
+ output_hidden_states=output_hidden_states,
690
+ return_dict=return_dict,
691
+ )
692
+
693
+ pooled_output = outputs[1]
694
+
695
+ pooled_output = self.dropout(pooled_output)
696
+ logits = self.classifier(pooled_output)
697
+
698
+ loss = None
699
+ if labels is not None:
700
+ assert self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int) # TODO: remove this
701
+ if self.config.problem_type is None:
702
+ if self.num_labels == 1:
703
+ self.config.problem_type = "regression"
704
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
705
+ self.config.problem_type = "single_label_classification"
706
+ else:
707
+ self.config.problem_type = "multi_label_classification"
708
+
709
+ if self.config.problem_type == "regression":
710
+ if self.num_labels == 1:
711
+ loss = F.mse_loss(logits.squeeze(), labels.squeeze())
712
+ else:
713
+ loss = F.mse_loss(logits, labels)
714
+ elif self.config.problem_type == "single_label_classification":
715
+ shared_head = BertForSequenceClassificationConfig.gen_shared_head(self)
716
+ criterion = BertForSequenceClassificationConfig.gen_criterion()
717
+ loss = trp_criterion(self.trp_blocks, shared_head, criterion, lambdas, pooled_output, logits, labels, loss_normalization)
718
+ elif self.config.problem_type == "multi_label_classification":
719
+ loss = F.binary_cross_entropy_with_logits(logits, labels)
720
+ if not return_dict:
721
+ output = (logits,) + outputs[2:]
722
+ return ((loss,) + output) if loss is not None else output
723
+
724
+ return SequenceClassifierOutput(
725
+ loss=loss,
726
+ logits=logits,
727
+ hidden_states=outputs.hidden_states,
728
+ attentions=outputs.attentions,
729
+ )
730
+ return func
731
+
732
+
733
+ # Boberta for Text Classification
734
+ class RobertaForSequenceClassificationConfig(BertForSequenceClassificationConfig):
735
+ @staticmethod
736
+ def gen_shared_head(self):
737
+ def func(hidden_states):
738
+ """
739
+ Args:
740
+ hidden_states (Tensor): Hidden States of shape [B, hidden_units].
741
+
742
+ Returns:
743
+ logits (Tensor): Logits tensor of shape [B, C].
744
+ """
745
+ logits = self.classifier(hidden_states)
746
+ return logits
747
+ return func
748
+
749
+ @staticmethod
750
+ def gen_forward(lambdas, loss_normalization=False):
751
+ def func(
752
+ self,
753
+ input_ids: Optional[torch.LongTensor] = None,
754
+ attention_mask: Optional[torch.FloatTensor] = None,
755
+ token_type_ids: Optional[torch.LongTensor] = None,
756
+ position_ids: Optional[torch.LongTensor] = None,
757
+ head_mask: Optional[torch.FloatTensor] = None,
758
+ inputs_embeds: Optional[torch.FloatTensor] = None,
759
+ labels: Optional[torch.LongTensor] = None,
760
+ output_attentions: Optional[bool] = None,
761
+ output_hidden_states: Optional[bool] = None,
762
+ return_dict: Optional[bool] = None,
763
+ ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
764
+ r"""
765
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
766
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
767
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
768
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
769
+ """
770
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
771
+
772
+ outputs = self.roberta(
773
+ input_ids,
774
+ attention_mask=attention_mask,
775
+ token_type_ids=token_type_ids,
776
+ position_ids=position_ids,
777
+ head_mask=head_mask,
778
+ inputs_embeds=inputs_embeds,
779
+ output_attentions=output_attentions,
780
+ output_hidden_states=output_hidden_states,
781
+ return_dict=return_dict,
782
+ )
783
+ sequence_output = outputs[0]
784
+ logits = self.classifier(sequence_output)
785
+
786
+ loss = None
787
+ if labels is not None:
788
+ assert self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int) # TODO: remove this
789
+ # move labels to correct device to enable model parallelism
790
+ labels = labels.to(logits.device)
791
+ if self.config.problem_type is None:
792
+ if self.num_labels == 1:
793
+ self.config.problem_type = "regression"
794
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
795
+ self.config.problem_type = "single_label_classification"
796
+ else:
797
+ self.config.problem_type = "multi_label_classification"
798
+
799
+ if self.config.problem_type == "regression":
800
+ if self.num_labels == 1:
801
+ loss = F.mse_loss(logits.squeeze(), labels.squeeze())
802
+ else:
803
+ loss = F.mse_loss(logits, labels)
804
+ elif self.config.problem_type == "single_label_classification":
805
+ shared_head = BertForSequenceClassificationConfig.gen_shared_head(self)
806
+ criterion = BertForSequenceClassificationConfig.gen_criterion()
807
+ loss = trp_criterion(self.trp_blocks, shared_head, criterion, lambdas, sequence_output, logits, labels, loss_normalization)
808
+ elif self.config.problem_type == "multi_label_classification":
809
+ loss = F.binary_cross_entropy_with_logits(logits, labels)
810
+
811
+ if not return_dict:
812
+ output = (logits,) + outputs[2:]
813
+ return ((loss,) + output) if loss is not None else output
814
+
815
+ return SequenceClassifierOutput(
816
+ loss=loss,
817
+ logits=logits,
818
+ hidden_states=outputs.hidden_states,
819
+ attentions=outputs.attentions,
820
+ )
821
+ return func
822
+
823
+
824
+ # Wav2Vec2 for Speech Recognition
825
+ class Wav2Vec2ForCTCConfig(Config):
826
+ _HIDDEN_STATES_START_POSITION = 2
827
+
828
+ @staticmethod
829
+ def greedy_decode_ctc(
830
+ log_probs: torch.Tensor,
831
+ input_lengths: torch.Tensor,
832
+ blank_token_id: int,
833
+ target_lengths: torch.Tensor
834
+ ):
835
+ """
836
+ Convert logits to flattened predictions that match the shape of flattened_targets.
837
+
838
+ Args:
839
+ log_probs: [B, L, V] - log-softmax output
840
+ input_lengths: [B] - actual length of each input
841
+ blank_token_id: int - index of blank token
842
+ target_lengths: [B] - used to determine how many predictions to keep per sample
843
+
844
+ Returns:
845
+ flattened_predictions: 1D tensor, same total length as sum(target_lengths)
846
+ """
847
+ batch_size = log_probs.size(0)
848
+ decoded_all = []
849
+
850
+ predicted_ids = log_probs.argmax(dim=-1) # [B, L]
851
+
852
+ for i in range(batch_size):
853
+ pred = predicted_ids[i][:input_lengths[i]] # [Li]
854
+ prev = None
855
+ decoded = []
856
+ for token in pred:
857
+ token = token.item()
858
+ if token != blank_token_id and token != prev:
859
+ decoded.append(token)
860
+ prev = token
861
+ # Trim or pad to match target_lengths[i]
862
+ tgt_len = target_lengths[i].item()
863
+ if len(decoded) >= tgt_len:
864
+ decoded = decoded[:tgt_len]
865
+ else:
866
+ decoded = decoded + [blank_token_id] * (tgt_len - len(decoded)) # pad with blank
867
+ decoded_all.extend(decoded)
868
+
869
+ return torch.tensor(decoded_all, dtype=torch.long, device=log_probs.device) # shape: [sum(target_lengths)]
870
+
871
+ @staticmethod
872
+ def gen_criterion(input_lengths: Tensor, pad_token_id: int, ctc_zero_infinity: bool):
873
+ def func(logits: Tensor, labels: Tensor, mask=None):
874
+ """
875
+ Args:
876
+ logits (Tensor): Log Probablities of shape [B, L, V].
877
+ labels (Tensor): Flattened Targets of shape [B, L'].
878
+
879
+ Returns:
880
+ loss (Tensor): Scalar tensor representing the loss.
881
+ mask (Tensor): Boolean mask tensor of shape [B].
882
+ """
883
+ if mask is None:
884
+ mask = torch.ones_like(input_lengths, dtype=torch.float32, device=input_lengths.device)
885
+
886
+ log_probs = nn.functional.log_softmax(logits, dim=-1, dtype=torch.float32).transpose(0, 1)
887
+ labels_mask = labels >= 0
888
+ target_lengths = labels_mask.sum(-1)
889
+ flattened_targets = labels.masked_select(labels_mask)
890
+ with torch.backends.cudnn.flags(enabled=False):
891
+ masked_losses = nn.functional.ctc_loss(log_probs, flattened_targets, input_lengths, target_lengths, blank=pad_token_id, reduction="none", zero_infinity=ctc_zero_infinity)
892
+ loss = torch.sum(mask * masked_losses) / (torch.sum(mask) + 1e-6)
893
+
894
+ with torch.no_grad():
895
+ thres = 0.5
896
+ flattened_predictions = Wav2Vec2ForCTCConfig.greedy_decode_ctc(
897
+ log_probs.transpose(0, 1), # [B, T, V]
898
+ input_lengths=input_lengths,
899
+ blank_token_id=pad_token_id,
900
+ target_lengths=target_lengths
901
+ )
902
+ token_wise_mask = torch.eq(flattened_predictions, flattened_targets).to(flattened_targets.dtype)
903
+ segment_ids = torch.arange(len(target_lengths), device=target_lengths.device).repeat_interleave(target_lengths)
904
+ sequence_wise_mask = torch.zeros(len(target_lengths), dtype=target_lengths.dtype, device=token_wise_mask.device).scatter_add(0, segment_ids, token_wise_mask)
905
+ mask = mask * torch.ge(sequence_wise_mask, thres * target_lengths).to(flattened_targets.dtype)
906
+
907
+ return loss, mask
908
+ return func
909
+
910
+ @staticmethod
911
+ def gen_shared_head(self):
912
+ def func(hidden_states):
913
+ """
914
+ Args:
915
+ hidden_states (Tensor): Hidden States of shape [B, C, hidden_units].
916
+
917
+ Returns:
918
+ logits (Tensor): Logits tensor of shape [B, C, 2].
919
+ """
920
+ logits = self.lm_head(hidden_states)
921
+ # log_probs = nn.functional.log_softmax(logits, dim=-1, dtype=torch.float32).transpose(0, 1)
922
+ return logits
923
+ return func
924
+
925
+ @staticmethod
926
+ def gen_forward(lambdas, loss_normalization=False):
927
+ def func(
928
+ self,
929
+ input_values: Optional[torch.Tensor],
930
+ attention_mask: Optional[torch.Tensor] = None,
931
+ output_attentions: Optional[bool] = None,
932
+ output_hidden_states: Optional[bool] = None,
933
+ return_dict: Optional[bool] = None,
934
+ labels: Optional[torch.Tensor] = None,
935
+ ) -> Union[Tuple, CausalLMOutput]:
936
+ r"""
937
+ labels (`torch.LongTensor` of shape `(batch_size, target_length)`, *optional*):
938
+ Labels for connectionist temporal classification. Note that `target_length` has to be smaller or equal to
939
+ the sequence length of the output logits. Indices are selected in `[-100, 0, ..., config.vocab_size - 1]`.
940
+ All labels set to `-100` are ignored (masked), the loss is only computed for labels in `[0, ...,
941
+ config.vocab_size - 1]`.
942
+ """
943
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
944
+
945
+ if labels is not None and labels.max() >= self.config.vocab_size:
946
+ raise ValueError(f"Label values must be <= vocab_size: {self.config.vocab_size}")
947
+
948
+ outputs = self.wav2vec2(
949
+ input_values,
950
+ attention_mask=attention_mask,
951
+ output_attentions=output_attentions,
952
+ output_hidden_states=output_hidden_states,
953
+ return_dict=return_dict,
954
+ )
955
+
956
+ hidden_states = outputs[0]
957
+ hidden_states = self.dropout(hidden_states)
958
+
959
+ logits = self.lm_head(hidden_states)
960
+
961
+ loss = None
962
+ if labels is not None:
963
+ # retrieve loss input_lengths from attention_mask
964
+ attention_mask = (
965
+ attention_mask if attention_mask is not None else torch.ones_like(input_values, dtype=torch.long)
966
+ )
967
+ input_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long)
968
+ shared_head = Wav2Vec2ForCTCConfig.gen_shared_head(self)
969
+ criterion = Wav2Vec2ForCTCConfig.gen_criterion(input_lengths, self.config.pad_token_id, self.config.ctc_zero_infinity)
970
+ loss = trp_criterion(self.trp_blocks, shared_head, criterion, lambdas, hidden_states, logits, labels, loss_normalization) # NOTE: Apply TRP!
971
+
972
+ if not return_dict:
973
+ output = (logits,) + outputs[Wav2Vec2ForCTCConfig._HIDDEN_STATES_START_POSITION:]
974
+ return ((loss,) + output) if loss is not None else output
975
+
976
+ return CausalLMOutput(
977
+ loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions
978
+ )
979
+ return func
980
+
981
+
982
+ # MBart for Translation
983
+ class MBartForConditionalGenerationConfig(Config):
984
+ @staticmethod
985
+ def gen_criterion(vocab_size: int, top_k=1):
986
+ def func(logits, labels, mask=None):
987
+ """
988
+ Args:
989
+ logits (Tensor): Logits tensor of shape [B, L, V].
990
+ labels (Tensor): Target labels of shape [B, L].
991
+
992
+ Returns:
993
+ loss (Tensor): Scalar tensor representing the loss.
994
+ mask (Tensor): Boolean mask tensor of shape [B].
995
+ """
996
+ if mask is None:
997
+ mask = torch.ones_like(labels.view(-1), dtype=torch.float32, device=labels.device)
998
+
999
+ masked_losses = F.cross_entropy(logits.view(-1, vocab_size), labels.view(-1), reduction="none")
1000
+ loss = torch.sum(mask * masked_losses) / (torch.sum(mask) + 1e-6)
1001
+
1002
+ with torch.no_grad():
1003
+ topk_values, topk_indices = torch.topk(logits.view(-1, vocab_size), top_k, dim=1)
1004
+ mask = mask * torch.eq(topk_indices, labels.view(-1, 1)).any(dim=1).to(logits.dtype)
1005
+
1006
+ return loss, mask
1007
+ return func
1008
+
1009
+ @staticmethod
1010
+ def gen_shared_head(self):
1011
+ def func(hidden_states):
1012
+ """
1013
+ Args:
1014
+ hidden_states (Tensor): Hidden States of shape [B, L, hidden_units].
1015
+
1016
+ Returns:
1017
+ logits (Tensor): Logits tensor of shape [B, L].
1018
+ """
1019
+ logits = self.lm_head(hidden_states) + self.final_logits_bias
1020
+ return logits
1021
+ return func
1022
+
1023
+ @staticmethod
1024
+ def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int):
1025
+ """
1026
+ Shift input ids one token to the right, and wrap the last non pad token (the <LID> token) Note that MBart does not
1027
+ have a single `decoder_start_token_id` in contrast to other Bart-like models.
1028
+ """
1029
+ prev_output_tokens = input_ids.clone()
1030
+
1031
+ if pad_token_id is None:
1032
+ raise ValueError("self.model.config.pad_token_id has to be defined.")
1033
+ # replace possible -100 values in labels by `pad_token_id`
1034
+ prev_output_tokens.masked_fill_(prev_output_tokens == -100, pad_token_id)
1035
+
1036
+ index_of_eos = (prev_output_tokens.ne(pad_token_id).sum(dim=1) - 1).unsqueeze(-1)
1037
+ decoder_start_tokens = prev_output_tokens.gather(1, index_of_eos).squeeze()
1038
+ prev_output_tokens[:, 1:] = prev_output_tokens[:, :-1].clone()
1039
+ prev_output_tokens[:, 0] = decoder_start_tokens
1040
+
1041
+ return prev_output_tokens
1042
+
1043
+ @staticmethod
1044
+ def gen_forward(lambdas, loss_normalization=False):
1045
+ def func(
1046
+ self,
1047
+ input_ids: torch.LongTensor = None,
1048
+ attention_mask: Optional[torch.Tensor] = None,
1049
+ decoder_input_ids: Optional[torch.LongTensor] = None,
1050
+ decoder_attention_mask: Optional[torch.LongTensor] = None,
1051
+ head_mask: Optional[torch.Tensor] = None,
1052
+ decoder_head_mask: Optional[torch.Tensor] = None,
1053
+ cross_attn_head_mask: Optional[torch.Tensor] = None,
1054
+ encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
1055
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
1056
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1057
+ decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
1058
+ labels: Optional[torch.LongTensor] = None,
1059
+ use_cache: Optional[bool] = None,
1060
+ output_attentions: Optional[bool] = None,
1061
+ output_hidden_states: Optional[bool] = None,
1062
+ return_dict: Optional[bool] = None,
1063
+ ) -> Union[Seq2SeqLMOutput, Tuple[torch.FloatTensor]]:
1064
+ r"""
1065
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1066
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
1067
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1068
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1069
+
1070
+ Returns:
1071
+
1072
+ """
1073
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1074
+
1075
+ if labels is not None:
1076
+ # if use_cache:
1077
+ # logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.")
1078
+ use_cache = False
1079
+ if decoder_input_ids is None and decoder_inputs_embeds is None:
1080
+ decoder_input_ids = MBartForConditionalGenerationConfig.shift_tokens_right(labels, self.config.pad_token_id)
1081
+
1082
+ outputs = self.model(
1083
+ input_ids,
1084
+ attention_mask=attention_mask,
1085
+ decoder_input_ids=decoder_input_ids,
1086
+ encoder_outputs=encoder_outputs,
1087
+ decoder_attention_mask=decoder_attention_mask,
1088
+ head_mask=head_mask,
1089
+ decoder_head_mask=decoder_head_mask,
1090
+ cross_attn_head_mask=cross_attn_head_mask,
1091
+ past_key_values=past_key_values,
1092
+ inputs_embeds=inputs_embeds,
1093
+ decoder_inputs_embeds=decoder_inputs_embeds,
1094
+ use_cache=use_cache,
1095
+ output_attentions=output_attentions,
1096
+ output_hidden_states=output_hidden_states,
1097
+ return_dict=return_dict,
1098
+ )
1099
+ lm_logits = self.lm_head(outputs[0]) + self.final_logits_bias
1100
+
1101
+ masked_lm_loss = None
1102
+ if labels is not None:
1103
+ shared_head = MBartForConditionalGenerationConfig.gen_shared_head(self)
1104
+ criterion = MBartForConditionalGenerationConfig.gen_criterion(self.config.vocab_size)
1105
+ masked_lm_loss = trp_criterion(self.trp_blocks, shared_head, criterion, lambdas, outputs[0], lm_logits, labels, loss_normalization)
1106
+
1107
+ if not return_dict:
1108
+ output = (lm_logits,) + outputs[1:]
1109
+ return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
1110
+
1111
+ return Seq2SeqLMOutput(
1112
+ loss=masked_lm_loss,
1113
+ logits=lm_logits,
1114
+ past_key_values=outputs.past_key_values,
1115
+ decoder_hidden_states=outputs.decoder_hidden_states,
1116
+ decoder_attentions=outputs.decoder_attentions,
1117
+ cross_attentions=outputs.cross_attentions,
1118
+ encoder_last_hidden_state=outputs.encoder_last_hidden_state,
1119
+ encoder_hidden_states=outputs.encoder_hidden_states,
1120
+ encoder_attentions=outputs.encoder_attentions,
1121
+ )
1122
+ return func
1123
+
1124
+
1125
+ def apply_trp(model, depths: int, p: float, lambdas: List[float], **kwargs):
1126
+ if isinstance(model, transformers.Wav2Vec2ForSequenceClassification):
1127
+ print("✅ Applying TRP to Wav2Vec2 for Audio Classification...")
1128
+ model.trp_blocks = torch.nn.ModuleList([TPBlock(depths, 768, p) for _ in lambdas])
1129
+ model.forward = types.MethodType(Wav2Vec2ForSequenceClassificationConfig.gen_forward(lambdas, False), model)
1130
+ elif isinstance(model, MobileNetV2):
1131
+ print("✅ Applying TRP to MobileNetV2 for Image Classification...")
1132
+ model.trp_blocks = torch.nn.ModuleList([TPBlock(depths, 1280, p) for _ in lambdas])
1133
+ model.forward = types.MethodType(MobileNetV2Config.gen_forward(lambdas, True, label_smoothing=kwargs["label_smoothing"], top_k=1), model)
1134
+ elif isinstance(model, ResNet):
1135
+ print("✅ Applying TRP to ResNet for Image Classification...")
1136
+ model.trp_blocks = torch.nn.ModuleList([TPBlock(depths, 2048, p) for _ in lambdas])
1137
+ model.forward = types.MethodType(ResNetConfig.gen_forward(lambdas, True, label_smoothing=kwargs["label_smoothing"], top_k=1), model)
1138
+ elif isinstance(model, EfficientNet):
1139
+ print("✅ Applying TRP to EfficientNet for Image Classification...")
1140
+ model.trp_blocks = torch.nn.ModuleList([TPBlock(depths, 1280, p) for _ in lambdas])
1141
+ model.forward = types.MethodType(EfficientNetConfig.gen_forward(lambdas, True, label_smoothing=kwargs["label_smoothing"], top_k=1), model)
1142
+ elif isinstance(model, VisionTransformer):
1143
+ print("✅ Applying TRP to VisionTransformer for Image Classification...")
1144
+ model.trp_blocks = torch.nn.ModuleList([TPBlock(depths, 768, p) for _ in lambdas])
1145
+ model.forward = types.MethodType(VisionTransformerConfig.gen_forward(lambdas, True, label_smoothing=kwargs["label_smoothing"], top_k=1), model)
1146
+ elif isinstance(model, transformers.BertForQuestionAnswering):
1147
+ print("✅ Applying TRP to Bert for Question Answering...")
1148
+ model.trp_blocks = torch.nn.ModuleList([TPBlock(depths, 768, p) for _ in lambdas])
1149
+ model.forward = types.MethodType(BertForQuestionAnsweringConfig.gen_forward(lambdas, True, 1), model)
1150
+ elif isinstance(model, FCN):
1151
+ print("✅ Applying TRP to FCN for Semantic Segmentation...")
1152
+ model.out_trp_blocks = torch.nn.ModuleList([TPBlock(depths, 2048, p, dim=1) for _ in lambdas])
1153
+ model.aux_trp_blocks = torch.nn.ModuleList([TPBlock(depths, 1024, p, dim=1) for _ in lambdas])
1154
+ model.forward = types.MethodType(FCNConfig.gen_forward(lambdas, True, 1), model)
1155
+ elif isinstance(model, DeepLabV3):
1156
+ print("✅ Applying TRP to DeepLabV3 for Semantic Segmentation...")
1157
+ model.out_trp_blocks = torch.nn.ModuleList([TPBlock(depths, 2048, p, dim=1) for _ in lambdas])
1158
+ model.aux_trp_blocks = torch.nn.ModuleList([TPBlock(depths, 1024, p, dim=1) for _ in lambdas])
1159
+ model.forward = types.MethodType(DeepLabV3Config.gen_forward(lambdas, True, 1), model)
1160
+ elif isinstance(model, transformers.BertForSequenceClassification):
1161
+ print("✅ Applying TRP to Bert for Text Classification...")
1162
+ model.trp_blocks = torch.nn.ModuleList([TPBlock(depths, 768, p) for _ in lambdas])
1163
+ model.forward = types.MethodType(BertForSequenceClassificationConfig.gen_forward(lambdas, False), model)
1164
+ elif isinstance(model, transformers.RobertaForSequenceClassification):
1165
+ print("✅ Applying TRP to Roberta for Text Classification...")
1166
+ model.trp_blocks = torch.nn.ModuleList([TPBlock(depths, 768, p) for _ in lambdas])
1167
+ model.forward = types.MethodType(RobertaForSequenceClassificationConfig.gen_forward(lambdas, False), model)
1168
+ elif isinstance(model, transformers.Wav2Vec2ForCTC):
1169
+ print("✅ Applying TRP to Wav2Vec2 for Speech Recognition...")
1170
+ model.trp_blocks = torch.nn.ModuleList([TPBlock(depths, 1024, p) for _ in lambdas])
1171
+ model.forward = types.MethodType(Wav2Vec2ForCTCConfig.gen_forward(lambdas, False), model)
1172
+ elif isinstance(model, transformers.MBartForConditionalGeneration):
1173
+ print("✅ Applying TRP to MBart for Translation...")
1174
+ model.trp_blocks = torch.nn.ModuleList([TPBlock(depths, 1024, p) for _ in lambdas])
1175
+ model.forward = types.MethodType(MBartForConditionalGenerationConfig.gen_forward(lambdas, False), model)
1176
+ else:
1177
+ torch._assert(
1178
+ isinstance(model, transformers.Wav2Vec2ForSequenceClassification),
1179
+ "The model should be an object of [`Wav2Vec2ForSequenceClassification`].")
1180
+
1181
+ return model
hpo-examples/image-classification/utils.py ADDED
@@ -0,0 +1,465 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import datetime
3
+ import errno
4
+ import hashlib
5
+ import os
6
+ import time
7
+ from collections import defaultdict, deque, OrderedDict
8
+ from typing import List, Optional, Tuple
9
+
10
+ import torch
11
+ import torch.distributed as dist
12
+
13
+
14
+ class SmoothedValue:
15
+ """Track a series of values and provide access to smoothed values over a
16
+ window or the global series average.
17
+ """
18
+
19
+ def __init__(self, window_size=20, fmt=None):
20
+ if fmt is None:
21
+ fmt = "{median:.4f} ({global_avg:.4f})"
22
+ self.deque = deque(maxlen=window_size)
23
+ self.total = 0.0
24
+ self.count = 0
25
+ self.fmt = fmt
26
+
27
+ def update(self, value, n=1):
28
+ self.deque.append(value)
29
+ self.count += n
30
+ self.total += value * n
31
+
32
+ def synchronize_between_processes(self):
33
+ """
34
+ Warning: does not synchronize the deque!
35
+ """
36
+ t = reduce_across_processes([self.count, self.total])
37
+ t = t.tolist()
38
+ self.count = int(t[0])
39
+ self.total = t[1]
40
+
41
+ @property
42
+ def median(self):
43
+ d = torch.tensor(list(self.deque))
44
+ return d.median().item()
45
+
46
+ @property
47
+ def avg(self):
48
+ d = torch.tensor(list(self.deque), dtype=torch.float32)
49
+ return d.mean().item()
50
+
51
+ @property
52
+ def global_avg(self):
53
+ return self.total / self.count
54
+
55
+ @property
56
+ def max(self):
57
+ return max(self.deque)
58
+
59
+ @property
60
+ def value(self):
61
+ return self.deque[-1]
62
+
63
+ def __str__(self):
64
+ return self.fmt.format(
65
+ median=self.median, avg=self.avg, global_avg=self.global_avg, max=self.max, value=self.value
66
+ )
67
+
68
+
69
+ class MetricLogger:
70
+ def __init__(self, delimiter="\t"):
71
+ self.meters = defaultdict(SmoothedValue)
72
+ self.delimiter = delimiter
73
+
74
+ def update(self, **kwargs):
75
+ for k, v in kwargs.items():
76
+ if isinstance(v, torch.Tensor):
77
+ v = v.item()
78
+ assert isinstance(v, (float, int))
79
+ self.meters[k].update(v)
80
+
81
+ def __getattr__(self, attr):
82
+ if attr in self.meters:
83
+ return self.meters[attr]
84
+ if attr in self.__dict__:
85
+ return self.__dict__[attr]
86
+ raise AttributeError(f"'{type(self).__name__}' object has no attribute '{attr}'")
87
+
88
+ def __str__(self):
89
+ loss_str = []
90
+ for name, meter in self.meters.items():
91
+ loss_str.append(f"{name}: {str(meter)}")
92
+ return self.delimiter.join(loss_str)
93
+
94
+ def synchronize_between_processes(self):
95
+ for meter in self.meters.values():
96
+ meter.synchronize_between_processes()
97
+
98
+ def add_meter(self, name, meter):
99
+ self.meters[name] = meter
100
+
101
+ def log_every(self, iterable, print_freq, header=None):
102
+ i = 0
103
+ if not header:
104
+ header = ""
105
+ start_time = time.time()
106
+ end = time.time()
107
+ iter_time = SmoothedValue(fmt="{avg:.4f}")
108
+ data_time = SmoothedValue(fmt="{avg:.4f}")
109
+ space_fmt = ":" + str(len(str(len(iterable)))) + "d"
110
+ if torch.cuda.is_available():
111
+ log_msg = self.delimiter.join(
112
+ [
113
+ header,
114
+ "[{0" + space_fmt + "}/{1}]",
115
+ "eta: {eta}",
116
+ "{meters}",
117
+ "time: {time}",
118
+ "data: {data}",
119
+ "max mem: {memory:.0f}",
120
+ ]
121
+ )
122
+ else:
123
+ log_msg = self.delimiter.join(
124
+ [header, "[{0" + space_fmt + "}/{1}]", "eta: {eta}", "{meters}", "time: {time}", "data: {data}"]
125
+ )
126
+ MB = 1024.0 * 1024.0
127
+ for obj in iterable:
128
+ data_time.update(time.time() - end)
129
+ yield obj
130
+ iter_time.update(time.time() - end)
131
+ if i % print_freq == 0:
132
+ eta_seconds = iter_time.global_avg * (len(iterable) - i)
133
+ eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
134
+ if torch.cuda.is_available():
135
+ print(
136
+ log_msg.format(
137
+ i,
138
+ len(iterable),
139
+ eta=eta_string,
140
+ meters=str(self),
141
+ time=str(iter_time),
142
+ data=str(data_time),
143
+ memory=torch.cuda.max_memory_allocated() / MB,
144
+ )
145
+ )
146
+ else:
147
+ print(
148
+ log_msg.format(
149
+ i, len(iterable), eta=eta_string, meters=str(self), time=str(iter_time), data=str(data_time)
150
+ )
151
+ )
152
+ i += 1
153
+ end = time.time()
154
+ total_time = time.time() - start_time
155
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
156
+ print(f"{header} Total time: {total_time_str}")
157
+
158
+
159
+ class ExponentialMovingAverage(torch.optim.swa_utils.AveragedModel):
160
+ """Maintains moving averages of model parameters using an exponential decay.
161
+ ``ema_avg = decay * avg_model_param + (1 - decay) * model_param``
162
+ `torch.optim.swa_utils.AveragedModel <https://pytorch.org/docs/stable/optim.html#custom-averaging-strategies>`_
163
+ is used to compute the EMA.
164
+ """
165
+
166
+ def __init__(self, model, decay, device="cpu"):
167
+ def ema_avg(avg_model_param, model_param, num_averaged):
168
+ return decay * avg_model_param + (1 - decay) * model_param
169
+
170
+ super().__init__(model, device, ema_avg, use_buffers=True)
171
+
172
+
173
+ def accuracy(output, target, topk=(1,)):
174
+ """Computes the accuracy over the k top predictions for the specified values of k"""
175
+ with torch.inference_mode():
176
+ maxk = max(topk)
177
+ batch_size = target.size(0)
178
+ if target.ndim == 2:
179
+ target = target.max(dim=1)[1]
180
+
181
+ _, pred = output.topk(maxk, 1, True, True)
182
+ pred = pred.t()
183
+ correct = pred.eq(target[None])
184
+
185
+ res = []
186
+ for k in topk:
187
+ correct_k = correct[:k].flatten().sum(dtype=torch.float32)
188
+ res.append(correct_k * (100.0 / batch_size))
189
+ return res
190
+
191
+
192
+ def mkdir(path):
193
+ try:
194
+ os.makedirs(path)
195
+ except OSError as e:
196
+ if e.errno != errno.EEXIST:
197
+ raise
198
+
199
+
200
+ def setup_for_distributed(is_master):
201
+ """
202
+ This function disables printing when not in master process
203
+ """
204
+ import builtins as __builtin__
205
+
206
+ builtin_print = __builtin__.print
207
+
208
+ def print(*args, **kwargs):
209
+ force = kwargs.pop("force", False)
210
+ if is_master or force:
211
+ builtin_print(*args, **kwargs)
212
+
213
+ __builtin__.print = print
214
+
215
+
216
+ def is_dist_avail_and_initialized():
217
+ if not dist.is_available():
218
+ return False
219
+ if not dist.is_initialized():
220
+ return False
221
+ return True
222
+
223
+
224
+ def get_world_size():
225
+ if not is_dist_avail_and_initialized():
226
+ return 1
227
+ return dist.get_world_size()
228
+
229
+
230
+ def get_rank():
231
+ if not is_dist_avail_and_initialized():
232
+ return 0
233
+ return dist.get_rank()
234
+
235
+
236
+ def is_main_process():
237
+ return get_rank() == 0
238
+
239
+
240
+ def save_on_master(*args, **kwargs):
241
+ if is_main_process():
242
+ torch.save(*args, **kwargs)
243
+
244
+
245
+ def init_distributed_mode(args):
246
+ if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
247
+ args.rank = int(os.environ["RANK"])
248
+ args.world_size = int(os.environ["WORLD_SIZE"])
249
+ args.gpu = int(os.environ["LOCAL_RANK"])
250
+ elif "SLURM_PROCID" in os.environ:
251
+ args.rank = int(os.environ["SLURM_PROCID"])
252
+ args.gpu = args.rank % torch.cuda.device_count()
253
+ elif hasattr(args, "rank"):
254
+ pass
255
+ else:
256
+ print("Not using distributed mode")
257
+ args.distributed = False
258
+ return
259
+
260
+ args.distributed = True
261
+
262
+ torch.cuda.set_device(args.gpu)
263
+ args.dist_backend = "nccl"
264
+ print(f"| distributed init (rank {args.rank}): {args.dist_url}", flush=True)
265
+ torch.distributed.init_process_group(
266
+ backend=args.dist_backend, init_method=args.dist_url, world_size=args.world_size, rank=args.rank
267
+ )
268
+ torch.distributed.barrier()
269
+ setup_for_distributed(args.rank == 0)
270
+
271
+
272
+ def average_checkpoints(inputs):
273
+ """Loads checkpoints from inputs and returns a model with averaged weights. Original implementation taken from:
274
+ https://github.com/pytorch/fairseq/blob/a48f235636557b8d3bc4922a6fa90f3a0fa57955/scripts/average_checkpoints.py#L16
275
+
276
+ Args:
277
+ inputs (List[str]): An iterable of string paths of checkpoints to load from.
278
+ Returns:
279
+ A dict of string keys mapping to various values. The 'model' key
280
+ from the returned dict should correspond to an OrderedDict mapping
281
+ string parameter names to torch Tensors.
282
+ """
283
+ params_dict = OrderedDict()
284
+ params_keys = None
285
+ new_state = None
286
+ num_models = len(inputs)
287
+ for fpath in inputs:
288
+ with open(fpath, "rb") as f:
289
+ state = torch.load(
290
+ f,
291
+ map_location=(lambda s, _: torch.serialization.default_restore_location(s, "cpu")),
292
+ )
293
+ # Copies over the settings from the first checkpoint
294
+ if new_state is None:
295
+ new_state = state
296
+ model_params = state["model"]
297
+ model_params_keys = list(model_params.keys())
298
+ if params_keys is None:
299
+ params_keys = model_params_keys
300
+ elif params_keys != model_params_keys:
301
+ raise KeyError(
302
+ f"For checkpoint {f}, expected list of params: {params_keys}, but found: {model_params_keys}"
303
+ )
304
+ for k in params_keys:
305
+ p = model_params[k]
306
+ if isinstance(p, torch.HalfTensor):
307
+ p = p.float()
308
+ if k not in params_dict:
309
+ params_dict[k] = p.clone()
310
+ # NOTE: clone() is needed in case of p is a shared parameter
311
+ else:
312
+ params_dict[k] += p
313
+ averaged_params = OrderedDict()
314
+ for k, v in params_dict.items():
315
+ averaged_params[k] = v
316
+ if averaged_params[k].is_floating_point():
317
+ averaged_params[k].div_(num_models)
318
+ else:
319
+ averaged_params[k] //= num_models
320
+ new_state["model"] = averaged_params
321
+ return new_state
322
+
323
+
324
+ def store_model_weights(model, checkpoint_path, checkpoint_key="model", strict=True):
325
+ """
326
+ This method can be used to prepare weights files for new models. It receives as
327
+ input a model architecture and a checkpoint from the training script and produces
328
+ a file with the weights ready for release.
329
+
330
+ Examples:
331
+ from torchvision import models as M
332
+
333
+ # Classification
334
+ model = M.mobilenet_v3_large(weights=None)
335
+ print(store_model_weights(model, './class.pth'))
336
+
337
+ # Quantized Classification
338
+ model = M.quantization.mobilenet_v3_large(weights=None, quantize=False)
339
+ model.fuse_model(is_qat=True)
340
+ model.qconfig = torch.ao.quantization.get_default_qat_qconfig('qnnpack')
341
+ _ = torch.ao.quantization.prepare_qat(model, inplace=True)
342
+ print(store_model_weights(model, './qat.pth'))
343
+
344
+ # Object Detection
345
+ model = M.detection.fasterrcnn_mobilenet_v3_large_fpn(weights=None, weights_backbone=None)
346
+ print(store_model_weights(model, './obj.pth'))
347
+
348
+ # Segmentation
349
+ model = M.segmentation.deeplabv3_mobilenet_v3_large(weights=None, weights_backbone=None, aux_loss=True)
350
+ print(store_model_weights(model, './segm.pth', strict=False))
351
+
352
+ Args:
353
+ model (pytorch.nn.Module): The model on which the weights will be loaded for validation purposes.
354
+ checkpoint_path (str): The path of the checkpoint we will load.
355
+ checkpoint_key (str, optional): The key of the checkpoint where the model weights are stored.
356
+ Default: "model".
357
+ strict (bool): whether to strictly enforce that the keys
358
+ in :attr:`state_dict` match the keys returned by this module's
359
+ :meth:`~torch.nn.Module.state_dict` function. Default: ``True``
360
+
361
+ Returns:
362
+ output_path (str): The location where the weights are saved.
363
+ """
364
+ # Store the new model next to the checkpoint_path
365
+ checkpoint_path = os.path.abspath(checkpoint_path)
366
+ output_dir = os.path.dirname(checkpoint_path)
367
+
368
+ # Deep copy to avoid side-effects on the model object.
369
+ model = copy.deepcopy(model)
370
+ checkpoint = torch.load(checkpoint_path, map_location="cpu")
371
+
372
+ # Load the weights to the model to validate that everything works
373
+ # and remove unnecessary weights (such as auxiliaries, etc)
374
+ if checkpoint_key == "model_ema":
375
+ del checkpoint[checkpoint_key]["n_averaged"]
376
+ torch.nn.modules.utils.consume_prefix_in_state_dict_if_present(checkpoint[checkpoint_key], "module.")
377
+ model.load_state_dict(checkpoint[checkpoint_key], strict=strict)
378
+
379
+ tmp_path = os.path.join(output_dir, str(model.__hash__()))
380
+ torch.save(model.state_dict(), tmp_path)
381
+
382
+ sha256_hash = hashlib.sha256()
383
+ with open(tmp_path, "rb") as f:
384
+ # Read and update hash string value in blocks of 4K
385
+ for byte_block in iter(lambda: f.read(4096), b""):
386
+ sha256_hash.update(byte_block)
387
+ hh = sha256_hash.hexdigest()
388
+
389
+ output_path = os.path.join(output_dir, "weights-" + str(hh[:8]) + ".pth")
390
+ os.replace(tmp_path, output_path)
391
+
392
+ return output_path
393
+
394
+
395
+ def reduce_across_processes(val):
396
+ if not is_dist_avail_and_initialized():
397
+ # nothing to sync, but we still convert to tensor for consistency with the distributed case.
398
+ return torch.tensor(val)
399
+
400
+ t = torch.tensor(val, device="cuda")
401
+ dist.barrier()
402
+ dist.all_reduce(t)
403
+ return t
404
+
405
+
406
+ def set_weight_decay(
407
+ model: torch.nn.Module,
408
+ weight_decay: float,
409
+ norm_weight_decay: Optional[float] = None,
410
+ norm_classes: Optional[List[type]] = None,
411
+ custom_keys_weight_decay: Optional[List[Tuple[str, float]]] = None,
412
+ ):
413
+ if not norm_classes:
414
+ norm_classes = [
415
+ torch.nn.modules.batchnorm._BatchNorm,
416
+ torch.nn.LayerNorm,
417
+ torch.nn.GroupNorm,
418
+ torch.nn.modules.instancenorm._InstanceNorm,
419
+ torch.nn.LocalResponseNorm,
420
+ ]
421
+ norm_classes = tuple(norm_classes)
422
+
423
+ params = {
424
+ "other": [],
425
+ "norm": [],
426
+ }
427
+ params_weight_decay = {
428
+ "other": weight_decay,
429
+ "norm": norm_weight_decay,
430
+ }
431
+ custom_keys = []
432
+ if custom_keys_weight_decay is not None:
433
+ for key, weight_decay in custom_keys_weight_decay:
434
+ params[key] = []
435
+ params_weight_decay[key] = weight_decay
436
+ custom_keys.append(key)
437
+
438
+ def _add_params(module, prefix=""):
439
+ for name, p in module.named_parameters(recurse=False):
440
+ if not p.requires_grad:
441
+ continue
442
+ is_custom_key = False
443
+ for key in custom_keys:
444
+ target_name = f"{prefix}.{name}" if prefix != "" and "." in key else name
445
+ if key == target_name:
446
+ params[key].append(p)
447
+ is_custom_key = True
448
+ break
449
+ if not is_custom_key:
450
+ if norm_weight_decay is not None and isinstance(module, norm_classes):
451
+ params["norm"].append(p)
452
+ else:
453
+ params["other"].append(p)
454
+
455
+ for child_name, child_module in module.named_children():
456
+ child_prefix = f"{prefix}.{child_name}" if prefix != "" else child_name
457
+ _add_params(child_module, prefix=child_prefix)
458
+
459
+ _add_params(model)
460
+
461
+ param_groups = []
462
+ for key in params:
463
+ if len(params[key]) > 0:
464
+ param_groups.append({"params": params[key], "weight_decay": params_weight_decay[key]})
465
+ return param_groups
hpo-examples/image-classification/vit_b_16/model_4.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:29f1bd991a3f27f7982e8700f31332cb94c4783b83e240fc71f1eca03d4eb468
3
+ size 1053172110
hpo-examples/question-answering/qa/README.md ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ library_name: transformers
3
+ license: apache-2.0
4
+ base_model: google-bert/bert-base-uncased
5
+ tags:
6
+ - generated_from_trainer
7
+ datasets:
8
+ - squad
9
+ model-index:
10
+ - name: baseline
11
+ results: []
12
+ ---
13
+
14
+ <!-- This model card has been generated automatically according to the information the Trainer had access to. You
15
+ should probably proofread and complete it, then remove this comment. -->
16
+
17
+ # baseline
18
+
19
+ This model is a fine-tuned version of [google-bert/bert-base-uncased](https://huggingface.co/google-bert/bert-base-uncased) on the squad dataset.
20
+
21
+ ## Model description
22
+
23
+ More information needed
24
+
25
+ ## Intended uses & limitations
26
+
27
+ More information needed
28
+
29
+ ## Training and evaluation data
30
+
31
+ More information needed
32
+
33
+ ## Training procedure
34
+
35
+ ### Training hyperparameters
36
+
37
+ The following hyperparameters were used during training:
38
+ - learning_rate: 3e-05
39
+ - train_batch_size: 12
40
+ - eval_batch_size: 8
41
+ - seed: 42
42
+ - optimizer: Use adamw_torch with betas=(0.9,0.999) and epsilon=1e-08 and optimizer_args=No additional optimizer arguments
43
+ - lr_scheduler_type: linear
44
+ - num_epochs: 2.0
45
+
46
+ ### Training results
47
+
48
+
49
+
50
+ ### Framework versions
51
+
52
+ - Transformers 4.49.0
53
+ - Pytorch 2.6.0+cu118
54
+ - Datasets 3.3.1
55
+ - Tokenizers 0.21.0
hpo-examples/question-answering/qa/all_results.json ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "epoch": 2.0,
3
+ "eval_exact_match": 81.49479659413434,
4
+ "eval_f1": 88.62945564424126,
5
+ "eval_runtime": 61.0301,
6
+ "eval_samples": 10784,
7
+ "eval_samples_per_second": 176.7,
8
+ "eval_steps_per_second": 22.087,
9
+ "total_flos": 3.541929151120589e+16,
10
+ "train_loss": 1.148573803161563,
11
+ "train_runtime": 3245.3985,
12
+ "train_samples": 88524,
13
+ "train_samples_per_second": 54.554,
14
+ "train_steps_per_second": 4.546
15
+ }
hpo-examples/question-answering/qa/config.json ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "google-bert/bert-base-uncased",
3
+ "architectures": [
4
+ "BertForQuestionAnswering"
5
+ ],
6
+ "attention_probs_dropout_prob": 0.1,
7
+ "classifier_dropout": null,
8
+ "gradient_checkpointing": false,
9
+ "hidden_act": "gelu",
10
+ "hidden_dropout_prob": 0.1,
11
+ "hidden_size": 768,
12
+ "initializer_range": 0.02,
13
+ "intermediate_size": 3072,
14
+ "layer_norm_eps": 1e-12,
15
+ "max_position_embeddings": 512,
16
+ "model_type": "bert",
17
+ "num_attention_heads": 12,
18
+ "num_hidden_layers": 12,
19
+ "pad_token_id": 0,
20
+ "position_embedding_type": "absolute",
21
+ "torch_dtype": "float32",
22
+ "transformers_version": "4.49.0",
23
+ "type_vocab_size": 2,
24
+ "use_cache": true,
25
+ "vocab_size": 30522
26
+ }
hpo-examples/question-answering/qa/eval_nbest_predictions.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8b8d44953cbe0ce20d1d1b62b72e7adba18bf1dc81d055492e22bfa21ff46657
3
+ size 49596120
hpo-examples/question-answering/qa/eval_predictions.json ADDED
The diff for this file is too large to render. See raw diff
 
hpo-examples/question-answering/qa/eval_results.json ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "epoch": 2.0,
3
+ "eval_exact_match": 81.49479659413434,
4
+ "eval_f1": 88.62945564424126,
5
+ "eval_runtime": 61.0301,
6
+ "eval_samples": 10784,
7
+ "eval_samples_per_second": 176.7,
8
+ "eval_steps_per_second": 22.087
9
+ }
hpo-examples/question-answering/qa/model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:38003bd65e4bfa70dd16886f29af7ab00d1aa0ae4de191b0a7de4d7883d17dde
3
+ size 442683784
hpo-examples/question-answering/qa/runs/May15_03-24-14_cs-Precision-7960-Tower/events.out.tfevents.1747293859.cs-Precision-7960-Tower.147971.0 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:36bfca6273a2422943de7b634cf75efd69b8e92079abe84df9e9c9e026d497f6
3
+ size 11535
hpo-examples/question-answering/qa/runs/May15_03-24-14_cs-Precision-7960-Tower/events.out.tfevents.1747297197.cs-Precision-7960-Tower.147971.1 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:259c79a03ba9c522b1fd728e92dae5cfc31c6cd73b2377d124749c83a0163910
3
+ size 412
hpo-examples/question-answering/qa/special_tokens_map.json ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ {
2
+ "cls_token": "[CLS]",
3
+ "mask_token": "[MASK]",
4
+ "pad_token": "[PAD]",
5
+ "sep_token": "[SEP]",
6
+ "unk_token": "[UNK]"
7
+ }
hpo-examples/question-answering/qa/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
hpo-examples/question-answering/qa/tokenizer_config.json ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "added_tokens_decoder": {
3
+ "0": {
4
+ "content": "[PAD]",
5
+ "lstrip": false,
6
+ "normalized": false,
7
+ "rstrip": false,
8
+ "single_word": false,
9
+ "special": true
10
+ },
11
+ "100": {
12
+ "content": "[UNK]",
13
+ "lstrip": false,
14
+ "normalized": false,
15
+ "rstrip": false,
16
+ "single_word": false,
17
+ "special": true
18
+ },
19
+ "101": {
20
+ "content": "[CLS]",
21
+ "lstrip": false,
22
+ "normalized": false,
23
+ "rstrip": false,
24
+ "single_word": false,
25
+ "special": true
26
+ },
27
+ "102": {
28
+ "content": "[SEP]",
29
+ "lstrip": false,
30
+ "normalized": false,
31
+ "rstrip": false,
32
+ "single_word": false,
33
+ "special": true
34
+ },
35
+ "103": {
36
+ "content": "[MASK]",
37
+ "lstrip": false,
38
+ "normalized": false,
39
+ "rstrip": false,
40
+ "single_word": false,
41
+ "special": true
42
+ }
43
+ },
44
+ "clean_up_tokenization_spaces": false,
45
+ "cls_token": "[CLS]",
46
+ "do_lower_case": true,
47
+ "extra_special_tokens": {},
48
+ "mask_token": "[MASK]",
49
+ "model_max_length": 512,
50
+ "pad_token": "[PAD]",
51
+ "sep_token": "[SEP]",
52
+ "strip_accents": null,
53
+ "tokenize_chinese_chars": true,
54
+ "tokenizer_class": "BertTokenizer",
55
+ "unk_token": "[UNK]"
56
+ }
hpo-examples/question-answering/qa/train_results.json ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "epoch": 2.0,
3
+ "total_flos": 3.541929151120589e+16,
4
+ "train_loss": 1.148573803161563,
5
+ "train_runtime": 3245.3985,
6
+ "train_samples": 88524,
7
+ "train_samples_per_second": 54.554,
8
+ "train_steps_per_second": 4.546
9
+ }
hpo-examples/question-answering/qa/trainer_state.json ADDED
@@ -0,0 +1,245 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "best_metric": null,
3
+ "best_model_checkpoint": null,
4
+ "epoch": 2.0,
5
+ "eval_steps": 500,
6
+ "global_step": 14754,
7
+ "is_hyper_param_search": false,
8
+ "is_local_process_zero": true,
9
+ "is_world_process_zero": true,
10
+ "log_history": [
11
+ {
12
+ "epoch": 0.06777822963264199,
13
+ "grad_norm": 31.397275924682617,
14
+ "learning_rate": 2.8983326555510372e-05,
15
+ "loss": 2.7299,
16
+ "step": 500
17
+ },
18
+ {
19
+ "epoch": 0.13555645926528398,
20
+ "grad_norm": 25.8492431640625,
21
+ "learning_rate": 2.796665311102074e-05,
22
+ "loss": 1.752,
23
+ "step": 1000
24
+ },
25
+ {
26
+ "epoch": 0.203334688897926,
27
+ "grad_norm": 29.627431869506836,
28
+ "learning_rate": 2.694997966653111e-05,
29
+ "loss": 1.5588,
30
+ "step": 1500
31
+ },
32
+ {
33
+ "epoch": 0.27111291853056796,
34
+ "grad_norm": 21.147193908691406,
35
+ "learning_rate": 2.593330622204148e-05,
36
+ "loss": 1.5014,
37
+ "step": 2000
38
+ },
39
+ {
40
+ "epoch": 0.33889114816321,
41
+ "grad_norm": 17.81966781616211,
42
+ "learning_rate": 2.491663277755185e-05,
43
+ "loss": 1.4768,
44
+ "step": 2500
45
+ },
46
+ {
47
+ "epoch": 0.406669377795852,
48
+ "grad_norm": 20.26822853088379,
49
+ "learning_rate": 2.389995933306222e-05,
50
+ "loss": 1.4064,
51
+ "step": 3000
52
+ },
53
+ {
54
+ "epoch": 0.47444760742849396,
55
+ "grad_norm": 16.216028213500977,
56
+ "learning_rate": 2.288328588857259e-05,
57
+ "loss": 1.3502,
58
+ "step": 3500
59
+ },
60
+ {
61
+ "epoch": 0.5422258370611359,
62
+ "grad_norm": 17.930505752563477,
63
+ "learning_rate": 2.1866612444082963e-05,
64
+ "loss": 1.3101,
65
+ "step": 4000
66
+ },
67
+ {
68
+ "epoch": 0.6100040666937779,
69
+ "grad_norm": 26.499574661254883,
70
+ "learning_rate": 2.084993899959333e-05,
71
+ "loss": 1.2922,
72
+ "step": 4500
73
+ },
74
+ {
75
+ "epoch": 0.67778229632642,
76
+ "grad_norm": 26.83368492126465,
77
+ "learning_rate": 1.9833265555103702e-05,
78
+ "loss": 1.3053,
79
+ "step": 5000
80
+ },
81
+ {
82
+ "epoch": 0.745560525959062,
83
+ "grad_norm": 22.85872459411621,
84
+ "learning_rate": 1.8816592110614073e-05,
85
+ "loss": 1.2555,
86
+ "step": 5500
87
+ },
88
+ {
89
+ "epoch": 0.813338755591704,
90
+ "grad_norm": 23.48080825805664,
91
+ "learning_rate": 1.779991866612444e-05,
92
+ "loss": 1.2068,
93
+ "step": 6000
94
+ },
95
+ {
96
+ "epoch": 0.8811169852243459,
97
+ "grad_norm": 20.919252395629883,
98
+ "learning_rate": 1.6783245221634812e-05,
99
+ "loss": 1.1991,
100
+ "step": 6500
101
+ },
102
+ {
103
+ "epoch": 0.9488952148569879,
104
+ "grad_norm": 23.9005126953125,
105
+ "learning_rate": 1.576657177714518e-05,
106
+ "loss": 1.2156,
107
+ "step": 7000
108
+ },
109
+ {
110
+ "epoch": 1.01667344448963,
111
+ "grad_norm": 22.660743713378906,
112
+ "learning_rate": 1.4749898332655551e-05,
113
+ "loss": 1.0827,
114
+ "step": 7500
115
+ },
116
+ {
117
+ "epoch": 1.0844516741222718,
118
+ "grad_norm": 25.28419303894043,
119
+ "learning_rate": 1.373322488816592e-05,
120
+ "loss": 0.8481,
121
+ "step": 8000
122
+ },
123
+ {
124
+ "epoch": 1.152229903754914,
125
+ "grad_norm": 14.510698318481445,
126
+ "learning_rate": 1.271655144367629e-05,
127
+ "loss": 0.872,
128
+ "step": 8500
129
+ },
130
+ {
131
+ "epoch": 1.2200081333875559,
132
+ "grad_norm": 29.12289810180664,
133
+ "learning_rate": 1.1699877999186661e-05,
134
+ "loss": 0.8375,
135
+ "step": 9000
136
+ },
137
+ {
138
+ "epoch": 1.287786363020198,
139
+ "grad_norm": 19.038454055786133,
140
+ "learning_rate": 1.0683204554697033e-05,
141
+ "loss": 0.8464,
142
+ "step": 9500
143
+ },
144
+ {
145
+ "epoch": 1.35556459265284,
146
+ "grad_norm": 21.09101676940918,
147
+ "learning_rate": 9.666531110207402e-06,
148
+ "loss": 0.8746,
149
+ "step": 10000
150
+ },
151
+ {
152
+ "epoch": 1.4233428222854818,
153
+ "grad_norm": 20.79250144958496,
154
+ "learning_rate": 8.649857665717772e-06,
155
+ "loss": 0.8776,
156
+ "step": 10500
157
+ },
158
+ {
159
+ "epoch": 1.491121051918124,
160
+ "grad_norm": 21.217571258544922,
161
+ "learning_rate": 7.633184221228141e-06,
162
+ "loss": 0.8523,
163
+ "step": 11000
164
+ },
165
+ {
166
+ "epoch": 1.5588992815507658,
167
+ "grad_norm": 15.557079315185547,
168
+ "learning_rate": 6.616510776738511e-06,
169
+ "loss": 0.8387,
170
+ "step": 11500
171
+ },
172
+ {
173
+ "epoch": 1.626677511183408,
174
+ "grad_norm": 14.53345012664795,
175
+ "learning_rate": 5.5998373322488825e-06,
176
+ "loss": 0.8377,
177
+ "step": 12000
178
+ },
179
+ {
180
+ "epoch": 1.6944557408160499,
181
+ "grad_norm": 26.921611785888672,
182
+ "learning_rate": 4.583163887759252e-06,
183
+ "loss": 0.8449,
184
+ "step": 12500
185
+ },
186
+ {
187
+ "epoch": 1.7622339704486918,
188
+ "grad_norm": 12.789366722106934,
189
+ "learning_rate": 3.566490443269622e-06,
190
+ "loss": 0.8547,
191
+ "step": 13000
192
+ },
193
+ {
194
+ "epoch": 1.830012200081334,
195
+ "grad_norm": 37.19759750366211,
196
+ "learning_rate": 2.549816998779992e-06,
197
+ "loss": 0.818,
198
+ "step": 13500
199
+ },
200
+ {
201
+ "epoch": 1.8977904297139758,
202
+ "grad_norm": 14.62682819366455,
203
+ "learning_rate": 1.533143554290362e-06,
204
+ "loss": 0.8128,
205
+ "step": 14000
206
+ },
207
+ {
208
+ "epoch": 1.965568659346618,
209
+ "grad_norm": 21.051790237426758,
210
+ "learning_rate": 5.164701098007319e-07,
211
+ "loss": 0.8115,
212
+ "step": 14500
213
+ },
214
+ {
215
+ "epoch": 2.0,
216
+ "step": 14754,
217
+ "total_flos": 3.541929151120589e+16,
218
+ "train_loss": 1.148573803161563,
219
+ "train_runtime": 3245.3985,
220
+ "train_samples_per_second": 54.554,
221
+ "train_steps_per_second": 4.546
222
+ }
223
+ ],
224
+ "logging_steps": 500,
225
+ "max_steps": 14754,
226
+ "num_input_tokens_seen": 0,
227
+ "num_train_epochs": 2,
228
+ "save_steps": 500,
229
+ "stateful_callbacks": {
230
+ "TrainerControl": {
231
+ "args": {
232
+ "should_epoch_stop": false,
233
+ "should_evaluate": false,
234
+ "should_log": false,
235
+ "should_save": true,
236
+ "should_training_stop": true
237
+ },
238
+ "attributes": {}
239
+ }
240
+ },
241
+ "total_flos": 3.541929151120589e+16,
242
+ "train_batch_size": 12,
243
+ "trial_name": null,
244
+ "trial_params": null
245
+ }
hpo-examples/question-answering/qa/training_args.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fe8e61ba1ca1cb106ca9adca5e9262fa9a262238814728a69256855c78c32f51
3
+ size 5304
hpo-examples/question-answering/qa/vocab.txt ADDED
The diff for this file is too large to render. See raw diff
 
hpo-examples/question-answering/requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ accelerate >= 0.12.0
2
+ datasets >= 1.8.0
3
+ torch >= 1.3.0
4
+ evaluate