Upload 178 files
This view is limited to 50 files because it contains too many changes. See the raw diff for the full change set.
- .gitattributes +1 -0
- hpo-examples/audio-classification/ac/README.md +88 -0
- hpo-examples/audio-classification/ac/all_results.json +13 -0
- hpo-examples/audio-classification/ac/config.json +147 -0
- hpo-examples/audio-classification/ac/eval_results.json +8 -0
- hpo-examples/audio-classification/ac/model.safetensors +3 -0
- hpo-examples/audio-classification/ac/preprocessor_config.json +9 -0
- hpo-examples/audio-classification/ac/runs/May15_03-06-03_cs-Precision-7960-Tower/events.out.tfevents.1747292768.cs-Precision-7960-Tower.146737.0 +3 -0
- hpo-examples/audio-classification/ac/runs/May15_03-06-03_cs-Precision-7960-Tower/events.out.tfevents.1747293535.cs-Precision-7960-Tower.146737.1 +3 -0
- hpo-examples/audio-classification/ac/train_results.json +8 -0
- hpo-examples/audio-classification/ac/trainer_state.json +1598 -0
- hpo-examples/audio-classification/ac/training_args.bin +3 -0
- hpo-examples/audio-classification/requirements.txt +5 -0
- hpo-examples/audio-classification/run.sh +30 -0
- hpo-examples/audio-classification/run_audio_classification.py +462 -0
- hpo-examples/audio-classification/trplib.py +1181 -0
- hpo-examples/image-classification/__pycache__/presets.cpython-310.pyc +0 -0
- hpo-examples/image-classification/__pycache__/sampler.cpython-310.pyc +0 -0
- hpo-examples/image-classification/__pycache__/transforms.cpython-310.pyc +0 -0
- hpo-examples/image-classification/__pycache__/trplib.cpython-310.pyc +0 -0
- hpo-examples/image-classification/__pycache__/utils.cpython-310.pyc +0 -0
- hpo-examples/image-classification/efficientnet_v2_m/model_7.pth +3 -0
- hpo-examples/image-classification/mobilenetv2/model_32.pth +3 -0
- hpo-examples/image-classification/presets.py +71 -0
- hpo-examples/image-classification/resnet50/model_35.pth +3 -0
- hpo-examples/image-classification/run.sh +49 -0
- hpo-examples/image-classification/sampler.py +62 -0
- hpo-examples/image-classification/train.py +524 -0
- hpo-examples/image-classification/train_quantization.py +265 -0
- hpo-examples/image-classification/transforms.py +183 -0
- hpo-examples/image-classification/trplib.py +1181 -0
- hpo-examples/image-classification/utils.py +465 -0
- hpo-examples/image-classification/vit_b_16/model_4.pth +3 -0
- hpo-examples/question-answering/qa/README.md +55 -0
- hpo-examples/question-answering/qa/all_results.json +15 -0
- hpo-examples/question-answering/qa/config.json +26 -0
- hpo-examples/question-answering/qa/eval_nbest_predictions.json +3 -0
- hpo-examples/question-answering/qa/eval_predictions.json +0 -0
- hpo-examples/question-answering/qa/eval_results.json +9 -0
- hpo-examples/question-answering/qa/model.safetensors +3 -0
- hpo-examples/question-answering/qa/runs/May15_03-24-14_cs-Precision-7960-Tower/events.out.tfevents.1747293859.cs-Precision-7960-Tower.147971.0 +3 -0
- hpo-examples/question-answering/qa/runs/May15_03-24-14_cs-Precision-7960-Tower/events.out.tfevents.1747297197.cs-Precision-7960-Tower.147971.1 +3 -0
- hpo-examples/question-answering/qa/special_tokens_map.json +7 -0
- hpo-examples/question-answering/qa/tokenizer.json +0 -0
- hpo-examples/question-answering/qa/tokenizer_config.json +56 -0
- hpo-examples/question-answering/qa/train_results.json +9 -0
- hpo-examples/question-answering/qa/trainer_state.json +245 -0
- hpo-examples/question-answering/qa/training_args.bin +3 -0
- hpo-examples/question-answering/qa/vocab.txt +0 -0
- hpo-examples/question-answering/requirements.txt +4 -0
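
Several of the listed directories (`ac/`, `qa/`, and the image-classification checkpoints) are complete model snapshots. As a quick sanity check after cloning, one of them can be loaded with the `transformers` pipeline API; this is a minimal sketch, and `sample.wav` is a placeholder path for a 16 kHz recording.

```python
from transformers import pipeline

# Load the keyword-spotting checkpoint uploaded in this commit
# (assumes the repo was cloned with git-lfs so model.safetensors is the real file).
classifier = pipeline(
    "audio-classification",
    model="hpo-examples/audio-classification/ac",
)

# "sample.wav" is a placeholder for a 16 kHz mono recording.
for prediction in classifier("sample.wav", top_k=3):
    print(f"{prediction['label']}: {prediction['score']:.3f}")
```
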
.gitattributes
CHANGED
@@ -37,3 +37,4 @@ qa/eval_nbest_predictions.json filter=lfs diff=lfs merge=lfs -text
 37   qa/sequential-policy-gradient.pdf filter=lfs diff=lfs merge=lfs -text
 38   sequential-policy-gradient.png filter=lfs diff=lfs merge=lfs -text
 39   examples/question-answering/qa/eval_nbest_predictions.json filter=lfs diff=lfs merge=lfs -text
 40 + hpo-examples/question-answering/qa/eval_nbest_predictions.json filter=lfs diff=lfs merge=lfs -text
hpo-examples/audio-classification/ac/README.md
ADDED
@@ -0,0 +1,88 @@
---
library_name: transformers
license: apache-2.0
base_model: facebook/wav2vec2-base
tags:
- audio-classification
- generated_from_trainer
datasets:
- superb
metrics:
- accuracy
model-index:
- name: wav2vec2-base-ft-keyword-spotting
  results:
  - task:
      name: Audio Classification
      type: audio-classification
    dataset:
      name: superb
      type: superb
      config: ks
      split: validation
      args: ks
    metrics:
    - name: Accuracy
      type: accuracy
      value: 0.9826419535157399
---

<!-- This model card has been generated automatically according to the information the Trainer had access to. You
should probably proofread and complete it, then remove this comment. -->

# wav2vec2-base-ft-keyword-spotting

This model is a fine-tuned version of [facebook/wav2vec2-base](https://huggingface.co/facebook/wav2vec2-base) on the superb dataset.
It achieves the following results on the evaluation set:
- Loss: 0.0954
- Accuracy: 0.9826

## Model description

More information needed

## Intended uses & limitations

More information needed

## Training and evaluation data

More information needed

## Training procedure

### Training hyperparameters

The following hyperparameters were used during training:
- learning_rate: 3e-05
- train_batch_size: 48
- eval_batch_size: 32
- seed: 0
- gradient_accumulation_steps: 4
- total_train_batch_size: 192
- optimizer: adamw_torch with betas=(0.9, 0.999), epsilon=1e-08, and no additional optimizer arguments
- lr_scheduler_type: linear
- lr_scheduler_warmup_ratio: 0.1
- num_epochs: 8.0
- mixed_precision_training: Native AMP

### Training results

| Training Loss | Epoch  | Step | Validation Loss | Accuracy |
|:-------------:|:------:|:----:|:---------------:|:--------:|
| 1.3624        | 1.0    | 267  | 1.1959          | 0.6546   |
| 0.3854        | 2.0    | 534  | 0.2675          | 0.9734   |
| 0.2473        | 3.0    | 801  | 0.1461          | 0.9768   |
| 0.1997        | 4.0    | 1068 | 0.1088          | 0.9804   |
| 0.1723        | 5.0    | 1335 | 0.0954          | 0.9826   |
| 0.1442        | 6.0    | 1602 | 0.0927          | 0.9813   |
| 0.1397        | 7.0    | 1869 | 0.0892          | 0.9812   |
| 0.1368        | 7.9728 | 2128 | 0.0896          | 0.9812   |

### Framework versions

- Transformers 4.49.0
- Pytorch 2.6.0+cu118
- Datasets 3.3.1
- Tokenizers 0.21.0
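
The hyperparameter list above maps almost one-to-one onto `TrainingArguments`. A hedged sketch of that configuration follows; the actual command line lives in `hpo-examples/audio-classification/run.sh`, and the `load_best_model_at_end` / `metric_for_best_model` settings are assumptions inferred from `trainer_state.json` rather than stated in the card.

```python
from transformers import TrainingArguments

# Sketch of the training configuration described in the card above (not the exact
# invocation -- see hpo-examples/audio-classification/run.sh for that).
training_args = TrainingArguments(
    output_dir="wav2vec2-base-ft-keyword-spotting",
    learning_rate=3e-5,
    per_device_train_batch_size=48,
    per_device_eval_batch_size=32,
    gradient_accumulation_steps=4,      # 48 x 4 = 192 total train batch size
    num_train_epochs=8.0,
    optim="adamw_torch",
    lr_scheduler_type="linear",
    warmup_ratio=0.1,
    seed=0,
    fp16=True,                          # "Native AMP" mixed precision
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,        # assumption, based on trainer_state.json's best checkpoint
    metric_for_best_model="accuracy",   # assumption
)
```
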
hpo-examples/audio-classification/ac/all_results.json
ADDED
@@ -0,0 +1,13 @@
{
    "epoch": 7.972769953051643,
    "eval_accuracy": 0.9826419535157399,
    "eval_loss": 0.09542840719223022,
    "eval_runtime": 5.5538,
    "eval_samples_per_second": 1224.023,
    "eval_steps_per_second": 38.352,
    "total_flos": 3.767900833756416e+18,
    "train_loss": 0.5178930132572812,
    "train_runtime": 756.2923,
    "train_samples_per_second": 540.468,
    "train_steps_per_second": 2.814
}
hpo-examples/audio-classification/ac/config.json
ADDED
@@ -0,0 +1,147 @@
{
    "_name_or_path": "facebook/wav2vec2-base",
    "activation_dropout": 0.0,
    "adapter_attn_dim": null,
    "adapter_kernel_size": 3,
    "adapter_stride": 2,
    "add_adapter": false,
    "apply_spec_augment": true,
    "architectures": ["Wav2Vec2ForSequenceClassification"],
    "attention_dropout": 0.1,
    "bos_token_id": 1,
    "classifier_proj_size": 256,
    "codevector_dim": 256,
    "contrastive_logits_temperature": 0.1,
    "conv_bias": false,
    "conv_dim": [512, 512, 512, 512, 512, 512, 512],
    "conv_kernel": [10, 3, 3, 3, 3, 2, 2],
    "conv_stride": [5, 2, 2, 2, 2, 2, 2],
    "ctc_loss_reduction": "sum",
    "ctc_zero_infinity": false,
    "diversity_loss_weight": 0.1,
    "do_stable_layer_norm": false,
    "eos_token_id": 2,
    "feat_extract_activation": "gelu",
    "feat_extract_norm": "group",
    "feat_proj_dropout": 0.1,
    "feat_quantizer_dropout": 0.0,
    "final_dropout": 0.0,
    "finetuning_task": "audio-classification",
    "freeze_feat_extract_train": true,
    "hidden_act": "gelu",
    "hidden_dropout": 0.1,
    "hidden_size": 768,
    "id2label": {
        "0": "yes", "1": "no", "2": "up", "3": "down", "4": "left", "5": "right",
        "6": "on", "7": "off", "8": "stop", "9": "go", "10": "_silence_", "11": "_unknown_"
    },
    "initializer_range": 0.02,
    "intermediate_size": 3072,
    "label2id": {
        "yes": "0", "no": "1", "up": "2", "down": "3", "left": "4", "right": "5",
        "on": "6", "off": "7", "stop": "8", "go": "9", "_silence_": "10", "_unknown_": "11"
    },
    "layer_norm_eps": 1e-05,
    "layerdrop": 0.0,
    "mask_channel_length": 10,
    "mask_channel_min_space": 1,
    "mask_channel_other": 0.0,
    "mask_channel_prob": 0.0,
    "mask_channel_selection": "static",
    "mask_feature_length": 10,
    "mask_feature_min_masks": 0,
    "mask_feature_prob": 0.0,
    "mask_time_length": 10,
    "mask_time_min_masks": 2,
    "mask_time_min_space": 1,
    "mask_time_other": 0.0,
    "mask_time_prob": 0.05,
    "mask_time_selection": "static",
    "model_type": "wav2vec2",
    "no_mask_channel_overlap": false,
    "no_mask_time_overlap": false,
    "num_adapter_layers": 3,
    "num_attention_heads": 12,
    "num_codevector_groups": 2,
    "num_codevectors_per_group": 320,
    "num_conv_pos_embedding_groups": 16,
    "num_conv_pos_embeddings": 128,
    "num_feat_extract_layers": 7,
    "num_hidden_layers": 12,
    "num_negatives": 100,
    "output_hidden_size": 768,
    "pad_token_id": 0,
    "proj_codevector_dim": 256,
    "tdnn_dilation": [1, 2, 3, 1, 1],
    "tdnn_dim": [512, 512, 512, 512, 1500],
    "tdnn_kernel": [5, 3, 3, 1, 1],
    "torch_dtype": "float32",
    "transformers_version": "4.49.0",
    "use_weighted_layer_sum": false,
    "vocab_size": 32,
    "xvector_output_dim": 512
}
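
The `id2label` / `label2id` tables above are what turn the classifier head's 12 output indices into Speech Commands keywords. A minimal sketch of reading them back through `AutoConfig`, using the directory added in this commit:

```python
from transformers import AutoConfig

config = AutoConfig.from_pretrained("hpo-examples/audio-classification/ac")

# transformers normalizes the string keys of id2label in config.json to ints on load.
print(config.num_labels)      # 12
print(config.id2label[4])     # "left"
print(config.id2label[10])    # "_silence_"
print(config.label2id["go"])  # "9" -- this particular config.json stores label2id values as strings
```
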
hpo-examples/audio-classification/ac/eval_results.json
ADDED
@@ -0,0 +1,8 @@
{
    "epoch": 7.972769953051643,
    "eval_accuracy": 0.9826419535157399,
    "eval_loss": 0.09542840719223022,
    "eval_runtime": 5.5538,
    "eval_samples_per_second": 1224.023,
    "eval_steps_per_second": 38.352
}
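
Fields such as `eval_accuracy` above come out of `Trainer.evaluate()` via a `compute_metrics` hook. The sketch below shows the usual accuracy hook built on the `evaluate` library; it is illustrative, and the actual implementation lives in `run_audio_classification.py`.

```python
import numpy as np
import evaluate

accuracy = evaluate.load("accuracy")

def compute_metrics(eval_pred):
    # eval_pred is the (logits, labels) pair handed over by the Trainer.
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    return accuracy.compute(predictions=predictions, references=labels)
```
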
hpo-examples/audio-classification/ac/model.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d6e4f1c85d883f3e41ebfab4cd7752ab2e6d6b968b847795be22e1e0662657a3
size 385400352
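
The three lines above are a Git LFS pointer; the 385 MB weights themselves are fetched by `git lfs pull`. Once present, they can be inspected without instantiating the model, as in this minimal sketch:

```python
from safetensors import safe_open

path = "hpo-examples/audio-classification/ac/model.safetensors"

# Assumes git-lfs has replaced the pointer file with the real 385 MB payload.
with safe_open(path, framework="pt", device="cpu") as f:
    for name in list(f.keys())[:5]:
        print(name, tuple(f.get_tensor(name).shape))
```
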
hpo-examples/audio-classification/ac/preprocessor_config.json
ADDED
@@ -0,0 +1,9 @@
{
    "do_normalize": true,
    "feature_extractor_type": "Wav2Vec2FeatureExtractor",
    "feature_size": 1,
    "padding_side": "right",
    "padding_value": 0.0,
    "return_attention_mask": false,
    "sampling_rate": 16000
}
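
The preprocessor pins the input format: 16 kHz mono float audio, zero-padded on the right, no attention mask. Audio at any other rate must be resampled before it reaches the feature extractor. A hedged sketch with a synthetic waveform standing in for real audio:

```python
import numpy as np
from transformers import AutoFeatureExtractor

feature_extractor = AutoFeatureExtractor.from_pretrained(
    "hpo-examples/audio-classification/ac"
)

# One second of a 440 Hz tone as a stand-in; real clips recorded at other rates
# must first be resampled to feature_extractor.sampling_rate (16000).
sr = feature_extractor.sampling_rate
t = np.linspace(0.0, 1.0, sr, endpoint=False)
waveform = 0.1 * np.sin(2 * np.pi * 440.0 * t)

batch = feature_extractor(
    waveform, sampling_rate=sr, padding=True, return_tensors="pt"
)
print(batch["input_values"].shape)  # torch.Size([1, 16000])
```
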
hpo-examples/audio-classification/ac/runs/May15_03-06-03_cs-Precision-7960-Tower/events.out.tfevents.1747292768.cs-Precision-7960-Tower.146737.0
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:75015326d15cb1786d8788b228e95bd7e952bad9e8e8be4ed02764cc17c8464c
size 54953
hpo-examples/audio-classification/ac/runs/May15_03-06-03_cs-Precision-7960-Tower/events.out.tfevents.1747293535.cs-Precision-7960-Tower.146737.1
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:215a99aa1f96db505ff53265c988d18ddcf70f16a1f13e713316a0354b768356
size 411
hpo-examples/audio-classification/ac/train_results.json
ADDED
@@ -0,0 +1,8 @@
{
    "epoch": 7.972769953051643,
    "total_flos": 3.767900833756416e+18,
    "train_loss": 0.5178930132572812,
    "train_runtime": 756.2923,
    "train_samples_per_second": 540.468,
    "train_steps_per_second": 2.814
}
hpo-examples/audio-classification/ac/trainer_state.json
ADDED
@@ -0,0 +1,1598 @@
{
    "best_metric": 0.9826419535157399,
    "best_model_checkpoint": "wav2vec2-base-ft-keyword-spotting/checkpoint-1335",
    "epoch": 7.972769953051643,
    "eval_steps": 500,
    "global_step": 2128,
    "is_hyper_param_search": false,
    "is_local_process_zero": true,
    "is_world_process_zero": true,
    "log_history": [
        {"epoch": 0.03755868544600939, "grad_norm": 2.0377416610717773, "learning_rate": 1.4084507042253521e-06, "loss": 3.8687, "step": 10},
        {"epoch": 0.07511737089201878, "grad_norm": 3.055781602859497, "learning_rate": 2.8169014084507042e-06, "loss": 4.1156, "step": 20},
        ... training-loss records continue every 10 steps, interleaved with the per-epoch eval records below ...
        {"epoch": 1.0, "eval_accuracy": 0.6546042953809944, "eval_loss": 1.19585382938385, "eval_runtime": 4.9178, "eval_samples_per_second": 1382.328, "eval_steps_per_second": 43.312, "step": 267},
        {"epoch": 2.0, "eval_accuracy": 0.9733745219182113, "eval_loss": 0.2675245702266693, "eval_runtime": 5.0739, "eval_samples_per_second": 1339.801, "eval_steps_per_second": 41.98, "step": 534},
        {"epoch": 3.0, "eval_accuracy": 0.9767578699617535, "eval_loss": 0.1461225152015686, "eval_runtime": 5.0057, "eval_samples_per_second": 1358.045, "eval_steps_per_second": 42.551, "step": 801},
        {"epoch": 4.0, "eval_accuracy": 0.980435422182995, "eval_loss": 0.10877315700054169, "eval_runtime": 4.9191, "eval_samples_per_second": 1381.955, "eval_steps_per_second": 43.3, "step": 1068},
        {"epoch": 5.0, "eval_accuracy": 0.9826419535157399, "eval_loss": 0.09542840719223022, "eval_runtime": 5.0389, "eval_samples_per_second": 1349.105, "eval_steps_per_second": 42.271, "step": 1335},
        {"epoch": 6.0, "eval_accuracy": 0.981318034716093, "eval_loss": 0.09270217269659042, "eval_runtime": 4.8524, "eval_samples_per_second": 1400.961, "eval_steps_per_second": 43.896, "step": 1602},
        ... the file continues through global_step 2128; it is cut off here by this limited diff view ...
"epoch": 6.781220657276995,
|
1327 |
+
"grad_norm": 2.4547784328460693,
|
1328 |
+
"learning_rate": 4.997389033942559e-06,
|
1329 |
+
"loss": 0.1198,
|
1330 |
+
"step": 1810
|
1331 |
+
},
|
1332 |
+
{
|
1333 |
+
"epoch": 6.818779342723005,
|
1334 |
+
"grad_norm": 2.604220390319824,
|
1335 |
+
"learning_rate": 4.840731070496084e-06,
|
1336 |
+
"loss": 0.1705,
|
1337 |
+
"step": 1820
|
1338 |
+
},
|
1339 |
+
{
|
1340 |
+
"epoch": 6.856338028169014,
|
1341 |
+
"grad_norm": 2.7237601280212402,
|
1342 |
+
"learning_rate": 4.684073107049609e-06,
|
1343 |
+
"loss": 0.1506,
|
1344 |
+
"step": 1830
|
1345 |
+
},
|
1346 |
+
{
|
1347 |
+
"epoch": 6.893896713615024,
|
1348 |
+
"grad_norm": 2.638058662414551,
|
1349 |
+
"learning_rate": 4.527415143603134e-06,
|
1350 |
+
"loss": 0.154,
|
1351 |
+
"step": 1840
|
1352 |
+
},
|
1353 |
+
{
|
1354 |
+
"epoch": 6.931455399061033,
|
1355 |
+
"grad_norm": 3.8382205963134766,
|
1356 |
+
"learning_rate": 4.3707571801566586e-06,
|
1357 |
+
"loss": 0.1553,
|
1358 |
+
"step": 1850
|
1359 |
+
},
|
1360 |
+
{
|
1361 |
+
"epoch": 6.969014084507043,
|
1362 |
+
"grad_norm": 2.071164131164551,
|
1363 |
+
"learning_rate": 4.2140992167101835e-06,
|
1364 |
+
"loss": 0.1397,
|
1365 |
+
"step": 1860
|
1366 |
+
},
|
1367 |
+
{
|
1368 |
+
"epoch": 7.0,
|
1369 |
+
"eval_accuracy": 0.9811709326272433,
|
1370 |
+
"eval_loss": 0.08920056372880936,
|
1371 |
+
"eval_runtime": 4.9166,
|
1372 |
+
"eval_samples_per_second": 1382.662,
|
1373 |
+
"eval_steps_per_second": 43.323,
|
1374 |
+
"step": 1869
|
1375 |
+
},
|
1376 |
+
{
|
1377 |
+
"epoch": 7.003755868544601,
|
1378 |
+
"grad_norm": 2.5346381664276123,
|
1379 |
+
"learning_rate": 4.0574412532637075e-06,
|
1380 |
+
"loss": 0.1296,
|
1381 |
+
"step": 1870
|
1382 |
+
},
|
1383 |
+
{
|
1384 |
+
"epoch": 7.041314553990611,
|
1385 |
+
"grad_norm": 2.575307846069336,
|
1386 |
+
"learning_rate": 3.9007832898172325e-06,
|
1387 |
+
"loss": 0.1389,
|
1388 |
+
"step": 1880
|
1389 |
+
},
|
1390 |
+
{
|
1391 |
+
"epoch": 7.07887323943662,
|
1392 |
+
"grad_norm": 2.0408527851104736,
|
1393 |
+
"learning_rate": 3.7441253263707574e-06,
|
1394 |
+
"loss": 0.1521,
|
1395 |
+
"step": 1890
|
1396 |
+
},
|
1397 |
+
{
|
1398 |
+
"epoch": 7.1164319248826295,
|
1399 |
+
"grad_norm": 3.2742061614990234,
|
1400 |
+
"learning_rate": 3.5874673629242823e-06,
|
1401 |
+
"loss": 0.1342,
|
1402 |
+
"step": 1900
|
1403 |
+
},
|
1404 |
+
{
|
1405 |
+
"epoch": 7.153990610328639,
|
1406 |
+
"grad_norm": 1.4502960443496704,
|
1407 |
+
"learning_rate": 3.4308093994778068e-06,
|
1408 |
+
"loss": 0.1204,
|
1409 |
+
"step": 1910
|
1410 |
+
},
|
1411 |
+
{
|
1412 |
+
"epoch": 7.191549295774648,
|
1413 |
+
"grad_norm": 3.7600743770599365,
|
1414 |
+
"learning_rate": 3.2741514360313317e-06,
|
1415 |
+
"loss": 0.1431,
|
1416 |
+
"step": 1920
|
1417 |
+
},
|
1418 |
+
{
|
1419 |
+
"epoch": 7.229107981220658,
|
1420 |
+
"grad_norm": 2.7332417964935303,
|
1421 |
+
"learning_rate": 3.1174934725848566e-06,
|
1422 |
+
"loss": 0.1281,
|
1423 |
+
"step": 1930
|
1424 |
+
},
|
1425 |
+
{
|
1426 |
+
"epoch": 7.266666666666667,
|
1427 |
+
"grad_norm": 2.6618921756744385,
|
1428 |
+
"learning_rate": 2.960835509138381e-06,
|
1429 |
+
"loss": 0.141,
|
1430 |
+
"step": 1940
|
1431 |
+
},
|
1432 |
+
{
|
1433 |
+
"epoch": 7.304225352112676,
|
1434 |
+
"grad_norm": 3.625688314437866,
|
1435 |
+
"learning_rate": 2.804177545691906e-06,
|
1436 |
+
"loss": 0.1455,
|
1437 |
+
"step": 1950
|
1438 |
+
},
|
1439 |
+
{
|
1440 |
+
"epoch": 7.341784037558686,
|
1441 |
+
"grad_norm": 2.0667765140533447,
|
1442 |
+
"learning_rate": 2.647519582245431e-06,
|
1443 |
+
"loss": 0.1359,
|
1444 |
+
"step": 1960
|
1445 |
+
},
|
1446 |
+
{
|
1447 |
+
"epoch": 7.379342723004695,
|
1448 |
+
"grad_norm": 2.369652509689331,
|
1449 |
+
"learning_rate": 2.490861618798956e-06,
|
1450 |
+
"loss": 0.1295,
|
1451 |
+
"step": 1970
|
1452 |
+
},
|
1453 |
+
{
|
1454 |
+
"epoch": 7.416901408450705,
|
1455 |
+
"grad_norm": 3.836838722229004,
|
1456 |
+
"learning_rate": 2.3342036553524807e-06,
|
1457 |
+
"loss": 0.1489,
|
1458 |
+
"step": 1980
|
1459 |
+
},
|
1460 |
+
{
|
1461 |
+
"epoch": 7.454460093896714,
|
1462 |
+
"grad_norm": 3.3261311054229736,
|
1463 |
+
"learning_rate": 2.1775456919060052e-06,
|
1464 |
+
"loss": 0.1289,
|
1465 |
+
"step": 1990
|
1466 |
+
},
|
1467 |
+
{
|
1468 |
+
"epoch": 7.492018779342723,
|
1469 |
+
"grad_norm": 2.6514954566955566,
|
1470 |
+
"learning_rate": 2.0208877284595297e-06,
|
1471 |
+
"loss": 0.1185,
|
1472 |
+
"step": 2000
|
1473 |
+
},
|
1474 |
+
{
|
1475 |
+
"epoch": 7.529577464788733,
|
1476 |
+
"grad_norm": 2.1017005443573,
|
1477 |
+
"learning_rate": 1.8642297650130548e-06,
|
1478 |
+
"loss": 0.1472,
|
1479 |
+
"step": 2010
|
1480 |
+
},
|
1481 |
+
{
|
1482 |
+
"epoch": 7.567136150234742,
|
1483 |
+
"grad_norm": 2.5104258060455322,
|
1484 |
+
"learning_rate": 1.7075718015665795e-06,
|
1485 |
+
"loss": 0.1467,
|
1486 |
+
"step": 2020
|
1487 |
+
},
|
1488 |
+
{
|
1489 |
+
"epoch": 7.6046948356807516,
|
1490 |
+
"grad_norm": 1.7915935516357422,
|
1491 |
+
"learning_rate": 1.5509138381201045e-06,
|
1492 |
+
"loss": 0.1212,
|
1493 |
+
"step": 2030
|
1494 |
+
},
|
1495 |
+
{
|
1496 |
+
"epoch": 7.642253521126761,
|
1497 |
+
"grad_norm": 2.4937989711761475,
|
1498 |
+
"learning_rate": 1.3942558746736294e-06,
|
1499 |
+
"loss": 0.1395,
|
1500 |
+
"step": 2040
|
1501 |
+
},
|
1502 |
+
{
|
1503 |
+
"epoch": 7.67981220657277,
|
1504 |
+
"grad_norm": 2.758594274520874,
|
1505 |
+
"learning_rate": 1.237597911227154e-06,
|
1506 |
+
"loss": 0.1361,
|
1507 |
+
"step": 2050
|
1508 |
+
},
|
1509 |
+
{
|
1510 |
+
"epoch": 7.71737089201878,
|
1511 |
+
"grad_norm": 2.291672468185425,
|
1512 |
+
"learning_rate": 1.0809399477806788e-06,
|
1513 |
+
"loss": 0.1182,
|
1514 |
+
"step": 2060
|
1515 |
+
},
|
1516 |
+
{
|
1517 |
+
"epoch": 7.754929577464789,
|
1518 |
+
"grad_norm": 1.944736361503601,
|
1519 |
+
"learning_rate": 9.242819843342037e-07,
|
1520 |
+
"loss": 0.1307,
|
1521 |
+
"step": 2070
|
1522 |
+
},
|
1523 |
+
{
|
1524 |
+
"epoch": 7.792488262910798,
|
1525 |
+
"grad_norm": 1.448411226272583,
|
1526 |
+
"learning_rate": 7.676240208877285e-07,
|
1527 |
+
"loss": 0.1407,
|
1528 |
+
"step": 2080
|
1529 |
+
},
|
1530 |
+
{
|
1531 |
+
"epoch": 7.830046948356808,
|
1532 |
+
"grad_norm": 3.276000499725342,
|
1533 |
+
"learning_rate": 6.109660574412533e-07,
|
1534 |
+
"loss": 0.1361,
|
1535 |
+
"step": 2090
|
1536 |
+
},
|
1537 |
+
{
|
1538 |
+
"epoch": 7.867605633802817,
|
1539 |
+
"grad_norm": 3.627788543701172,
|
1540 |
+
"learning_rate": 4.5430809399477806e-07,
|
1541 |
+
"loss": 0.131,
|
1542 |
+
"step": 2100
|
1543 |
+
},
|
1544 |
+
{
|
1545 |
+
"epoch": 7.905164319248827,
|
1546 |
+
"grad_norm": 1.2533661127090454,
|
1547 |
+
"learning_rate": 2.9765013054830287e-07,
|
1548 |
+
"loss": 0.1245,
|
1549 |
+
"step": 2110
|
1550 |
+
},
|
1551 |
+
{
|
1552 |
+
"epoch": 7.942723004694836,
|
1553 |
+
"grad_norm": 1.472484827041626,
|
1554 |
+
"learning_rate": 1.409921671018277e-07,
|
1555 |
+
"loss": 0.1368,
|
1556 |
+
"step": 2120
|
1557 |
+
},
|
1558 |
+
{
|
1559 |
+
"epoch": 7.972769953051643,
|
1560 |
+
"eval_accuracy": 0.9811709326272433,
|
1561 |
+
"eval_loss": 0.08957477658987045,
|
1562 |
+
"eval_runtime": 5.418,
|
1563 |
+
"eval_samples_per_second": 1254.697,
|
1564 |
+
"eval_steps_per_second": 39.313,
|
1565 |
+
"step": 2128
|
1566 |
+
},
|
1567 |
+
{
|
1568 |
+
"epoch": 7.972769953051643,
|
1569 |
+
"step": 2128,
|
1570 |
+
"total_flos": 3.767900833756416e+18,
|
1571 |
+
"train_loss": 0.5178930132572812,
|
1572 |
+
"train_runtime": 756.2923,
|
1573 |
+
"train_samples_per_second": 540.468,
|
1574 |
+
"train_steps_per_second": 2.814
|
1575 |
+
}
|
1576 |
+
],
|
1577 |
+
"logging_steps": 10,
|
1578 |
+
"max_steps": 2128,
|
1579 |
+
"num_input_tokens_seen": 0,
|
1580 |
+
"num_train_epochs": 8,
|
1581 |
+
"save_steps": 500,
|
1582 |
+
"stateful_callbacks": {
|
1583 |
+
"TrainerControl": {
|
1584 |
+
"args": {
|
1585 |
+
"should_epoch_stop": false,
|
1586 |
+
"should_evaluate": false,
|
1587 |
+
"should_log": false,
|
1588 |
+
"should_save": true,
|
1589 |
+
"should_training_stop": true
|
1590 |
+
},
|
1591 |
+
"attributes": {}
|
1592 |
+
}
|
1593 |
+
},
|
1594 |
+
"total_flos": 3.767900833756416e+18,
|
1595 |
+
"train_batch_size": 48,
|
1596 |
+
"trial_name": null,
|
1597 |
+
"trial_params": null
|
1598 |
+
}
|
hpo-examples/audio-classification/ac/training_args.bin
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:11e71a8faf8833e9bfb138217263fb0518314e3b8597902b752b2bc9dd143942
size 5368
hpo-examples/audio-classification/requirements.txt
ADDED
@@ -0,0 +1,5 @@
datasets>=1.14.0
evaluate
librosa
torchaudio
torch>=1.6
hpo-examples/audio-classification/run.sh
ADDED
@@ -0,0 +1,30 @@
CUDA_VISIBLE_DEVICES=0 python run_audio_classification.py \
    --model_name_or_path facebook/wav2vec2-base \
    --dataset_name superb \
    --dataset_config_name ks \
    --trust_remote_code \
    --output_dir wav2vec2-base-ft-keyword-spotting \
    --overwrite_output_dir \
    --remove_unused_columns False \
    --do_train \
    --do_eval \
    --fp16 \
    --learning_rate 3e-5 \
    --max_length_seconds 1 \
    --attention_mask False \
    --warmup_ratio 0.1 \
    --num_train_epochs 8 \
    --per_device_train_batch_size 64 \
    --gradient_accumulation_steps 4 \
    --per_device_eval_batch_size 32 \
    --dataloader_num_workers 4 \
    --logging_strategy steps \
    --logging_steps 10 \
    --eval_strategy epoch \
    --save_strategy epoch \
    --load_best_model_at_end True \
    --metric_for_best_model accuracy \
    --save_total_limit 3 \
    --seed 0 \
    --push_to_hub \
    --apply-trp --trp-depths 1 --trp-p 0.1 --trp-lambdas 0.4 0.2 0.1
hpo-examples/audio-classification/run_audio_classification.py
ADDED
@@ -0,0 +1,462 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os
import sys
import warnings
from dataclasses import dataclass, field
from random import randint
from typing import Optional, List

import datasets
import evaluate
import numpy as np
from datasets import DatasetDict, load_dataset

import transformers
from transformers import (
    AutoConfig,
    AutoFeatureExtractor,
    AutoModelForAudioClassification,
    HfArgumentParser,
    Trainer,
    TrainingArguments,
    set_seed,
)
from transformers.trainer_utils import get_last_checkpoint
from transformers.utils import check_min_version, send_example_telemetry
from transformers.utils.versions import require_version


from trplib import apply_trp


logger = logging.getLogger(__name__)

# # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
# check_min_version("4.50.0.dev0")

require_version("datasets>=1.14.0", "To fix: pip install -r examples/pytorch/audio-classification/requirements.txt")


def random_subsample(wav: np.ndarray, max_length: float, sample_rate: int = 16000):
    """Randomly sample chunks of `max_length` seconds from the input audio"""
    sample_length = int(round(sample_rate * max_length))
    if len(wav) <= sample_length:
        return wav
    random_offset = randint(0, len(wav) - sample_length - 1)
    return wav[random_offset : random_offset + sample_length]


@dataclass
class DataTrainingArguments:
    """
    Arguments pertaining to what data we are going to input our model for training and eval.
    Using `HfArgumentParser` we can turn this class
    into argparse arguments to be able to specify them on
    the command line.
    """

    dataset_name: Optional[str] = field(default=None, metadata={"help": "Name of a dataset from the datasets package"})
    dataset_config_name: Optional[str] = field(
        default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
    )
    train_file: Optional[str] = field(
        default=None, metadata={"help": "A file containing the training audio paths and labels."}
    )
    eval_file: Optional[str] = field(
        default=None, metadata={"help": "A file containing the validation audio paths and labels."}
    )
    train_split_name: str = field(
        default="train",
        metadata={
            "help": "The name of the training data set split to use (via the datasets library). Defaults to 'train'"
        },
    )
    eval_split_name: str = field(
        default="validation",
        metadata={
            "help": (
                "The name of the training data set split to use (via the datasets library). Defaults to 'validation'"
            )
        },
    )
    audio_column_name: str = field(
        default="audio",
        metadata={"help": "The name of the dataset column containing the audio data. Defaults to 'audio'"},
    )
    label_column_name: str = field(
        default="label", metadata={"help": "The name of the dataset column containing the labels. Defaults to 'label'"}
    )
    max_train_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "For debugging purposes or quicker training, truncate the number of training examples to this "
                "value if set."
            )
        },
    )
    max_eval_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
                "value if set."
            )
        },
    )
    max_length_seconds: float = field(
        default=20,
        metadata={"help": "Audio clips will be randomly cut to this length during training if the value is set."},
    )


@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
    """

    model_name_or_path: str = field(
        default="facebook/wav2vec2-base",
        metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"},
    )
    config_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
    )
    cache_dir: Optional[str] = field(
        default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from the Hub"}
    )
    model_revision: str = field(
        default="main",
        metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
    )
    feature_extractor_name: Optional[str] = field(
        default=None, metadata={"help": "Name or path of preprocessor config."}
    )
    freeze_feature_encoder: bool = field(
        default=True, metadata={"help": "Whether to freeze the feature encoder layers of the model."}
    )
    attention_mask: bool = field(
        default=True, metadata={"help": "Whether to generate an attention mask in the feature extractor."}
    )
    token: str = field(
        default=None,
        metadata={
            "help": (
                "The token to use as HTTP bearer authorization for remote files. If not specified, will use the token "
                "generated when running `huggingface-cli login` (stored in `~/.huggingface`)."
            )
        },
    )
    trust_remote_code: bool = field(
        default=False,
        metadata={
            "help": (
                "Whether to trust the execution of code from datasets/models defined on the Hub."
                " This option should only be set to `True` for repositories you trust and in which you have read the"
                " code, as it will execute code present on the Hub on your local machine."
            )
        },
    )
    freeze_feature_extractor: Optional[bool] = field(
        default=None, metadata={"help": "Whether to freeze the feature extractor layers of the model."}
    )
    ignore_mismatched_sizes: bool = field(
        default=False,
        metadata={"help": "Will enable to load a pretrained model whose head dimensions are different."},
    )

    apply_trp: Optional[bool] = field(
        default=False,
        metadata={"help": "Whether to apply trp or not."},
    )
    trp_depths: Optional[int] = field(
        default=1,
        metadata={
            "help": "TRP depth value."
        },
    )
    trp_p: Optional[float] = field(
        default=0.1,
        metadata={
            "help": "TRP p value."
        },
    )
    trp_lambdas: Optional[List[float]] = field(
        default_factory=lambda: [0.4, 0.2, 0.1],
        metadata={
            "help": "TRP lambda values (list of floats)."
        },
    )

    def __post_init__(self):
        if not self.freeze_feature_extractor and self.freeze_feature_encoder:
            warnings.warn(
                "The argument `--freeze_feature_extractor` is deprecated and "
                "will be removed in a future version. Use `--freeze_feature_encoder` "
                "instead. Setting `freeze_feature_encoder==True`.",
                FutureWarning,
            )
        if self.freeze_feature_extractor and not self.freeze_feature_encoder:
            raise ValueError(
                "The argument `--freeze_feature_extractor` is deprecated and "
                "should not be used in combination with `--freeze_feature_encoder`. "
                "Only make use of `--freeze_feature_encoder`."
            )


def main():
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.

    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
    # information sent is the one passed as arguments along with your Python/PyTorch versions.
    send_example_telemetry("run_audio_classification", model_args, data_args)

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )

    if training_args.should_log:
        # The default of training_args.log_level is passive, so we set log level at info here to have that default.
        transformers.utils.logging.set_verbosity_info()

    log_level = training_args.get_process_log_level()
    logger.setLevel(log_level)
    transformers.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.enable_default_handler()
    transformers.utils.logging.enable_explicit_format()

    # Log on each process the small summary:
    logger.warning(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
        + f"distributed training: {training_args.parallel_mode.value == 'distributed'}, 16-bits training: {training_args.fp16}"
    )
    logger.info(f"Training/evaluation parameters {training_args}")

    # Set seed before initializing model.
    set_seed(training_args.seed)

    # Detecting last checkpoint.
    last_checkpoint = None
    if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
        last_checkpoint = get_last_checkpoint(training_args.output_dir)
        if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
            raise ValueError(
                f"Output directory ({training_args.output_dir}) already exists and is not empty. "
                "Use --overwrite_output_dir to train from scratch."
            )
        elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
            logger.info(
                f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
                "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
            )

    # Initialize our dataset and prepare it for the audio classification task.
    raw_datasets = DatasetDict()
    raw_datasets["train"] = load_dataset(
        data_args.dataset_name,
        data_args.dataset_config_name,
        split=data_args.train_split_name,
        token=model_args.token,
        trust_remote_code=model_args.trust_remote_code,
    )
    raw_datasets["eval"] = load_dataset(
        data_args.dataset_name,
        data_args.dataset_config_name,
        split=data_args.eval_split_name,
        token=model_args.token,
        trust_remote_code=model_args.trust_remote_code,
    )

    if data_args.audio_column_name not in raw_datasets["train"].column_names:
        raise ValueError(
            f"--audio_column_name {data_args.audio_column_name} not found in dataset '{data_args.dataset_name}'. "
            "Make sure to set `--audio_column_name` to the correct audio column - one of "
            f"{', '.join(raw_datasets['train'].column_names)}."
        )

    if data_args.label_column_name not in raw_datasets["train"].column_names:
        raise ValueError(
            f"--label_column_name {data_args.label_column_name} not found in dataset '{data_args.dataset_name}'. "
            "Make sure to set `--label_column_name` to the correct text column - one of "
            f"{', '.join(raw_datasets['train'].column_names)}."
        )

    # Setting `return_attention_mask=True` is the way to get a correctly masked mean-pooling over
    # transformer outputs in the classifier, but it doesn't always lead to better accuracy
    feature_extractor = AutoFeatureExtractor.from_pretrained(
        model_args.feature_extractor_name or model_args.model_name_or_path,
        return_attention_mask=model_args.attention_mask,
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
        token=model_args.token,
        trust_remote_code=model_args.trust_remote_code,
    )

    # `datasets` takes care of automatically loading and resampling the audio,
    # so we just need to set the correct target sampling rate.
    raw_datasets = raw_datasets.cast_column(
        data_args.audio_column_name, datasets.features.Audio(sampling_rate=feature_extractor.sampling_rate)
    )

    model_input_name = feature_extractor.model_input_names[0]

    def train_transforms(batch):
        """Apply train_transforms across a batch."""
        subsampled_wavs = []
        for audio in batch[data_args.audio_column_name]:
            wav = random_subsample(
                audio["array"], max_length=data_args.max_length_seconds, sample_rate=feature_extractor.sampling_rate
            )
            subsampled_wavs.append(wav)
        inputs = feature_extractor(subsampled_wavs, sampling_rate=feature_extractor.sampling_rate)
        output_batch = {model_input_name: inputs.get(model_input_name)}
        output_batch["labels"] = list(batch[data_args.label_column_name])

        return output_batch

    def val_transforms(batch):
        """Apply val_transforms across a batch."""
        wavs = [audio["array"] for audio in batch[data_args.audio_column_name]]
        inputs = feature_extractor(wavs, sampling_rate=feature_extractor.sampling_rate)
        output_batch = {model_input_name: inputs.get(model_input_name)}
        output_batch["labels"] = list(batch[data_args.label_column_name])

        return output_batch

    # Prepare label mappings.
    # We'll include these in the model's config to get human readable labels in the Inference API.
    labels = raw_datasets["train"].features[data_args.label_column_name].names
    label2id, id2label = {}, {}
    for i, label in enumerate(labels):
        label2id[label] = str(i)
        id2label[str(i)] = label

    # Load the accuracy metric from the datasets package
    metric = evaluate.load("accuracy", cache_dir=model_args.cache_dir)

    # Define our compute_metrics function. It takes an `EvalPrediction` object (a namedtuple with
    # `predictions` and `label_ids` fields) and has to return a dictionary string to float.
    def compute_metrics(eval_pred):
        """Computes accuracy on a batch of predictions"""
        predictions = np.argmax(eval_pred.predictions, axis=1)
        return metric.compute(predictions=predictions, references=eval_pred.label_ids)

    config = AutoConfig.from_pretrained(
        model_args.config_name or model_args.model_name_or_path,
        num_labels=len(labels),
        label2id=label2id,
        id2label=id2label,
        finetuning_task="audio-classification",
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
        token=model_args.token,
        trust_remote_code=model_args.trust_remote_code,
    )
    model = AutoModelForAudioClassification.from_pretrained(
        model_args.model_name_or_path,
        from_tf=bool(".ckpt" in model_args.model_name_or_path),
        config=config,
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
        token=model_args.token,
        trust_remote_code=model_args.trust_remote_code,
        ignore_mismatched_sizes=model_args.ignore_mismatched_sizes,
    )

    # freeze the convolutional waveform encoder
    if model_args.freeze_feature_encoder:
        model.freeze_feature_encoder()

    if model_args.apply_trp:
        model = apply_trp(model, depths=model_args.trp_depths, p=model_args.trp_p, lambdas=model_args.trp_lambdas)

    if training_args.do_train:
        if data_args.max_train_samples is not None:
            raw_datasets["train"] = (
                raw_datasets["train"].shuffle(seed=training_args.seed).select(range(data_args.max_train_samples))
            )
        # Set the training transforms
        raw_datasets["train"].set_transform(train_transforms, output_all_columns=False)

    if training_args.do_eval:
        if data_args.max_eval_samples is not None:
            raw_datasets["eval"] = (
                raw_datasets["eval"].shuffle(seed=training_args.seed).select(range(data_args.max_eval_samples))
            )
        # Set the validation transforms
        raw_datasets["eval"].set_transform(val_transforms, output_all_columns=False)

    # Initialize our trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=raw_datasets["train"] if training_args.do_train else None,
        eval_dataset=raw_datasets["eval"] if training_args.do_eval else None,
        compute_metrics=compute_metrics,
        processing_class=feature_extractor,
    )

    # Training
    if training_args.do_train:
        checkpoint = None
        if training_args.resume_from_checkpoint is not None:
            checkpoint = training_args.resume_from_checkpoint
        elif last_checkpoint is not None:
            checkpoint = last_checkpoint
        train_result = trainer.train(resume_from_checkpoint=checkpoint)
        trainer.save_model()
        trainer.log_metrics("train", train_result.metrics)
        trainer.save_metrics("train", train_result.metrics)
        trainer.save_state()

    # Evaluation
    if training_args.do_eval:
        metrics = trainer.evaluate()
        trainer.log_metrics("eval", metrics)
        trainer.save_metrics("eval", metrics)

    # Write model card and (optionally) push to hub
    kwargs = {
        "finetuned_from": model_args.model_name_or_path,
        "tasks": "audio-classification",
        "dataset": data_args.dataset_name,
        "tags": ["audio-classification"],
    }
    if training_args.push_to_hub:
        trainer.push_to_hub(**kwargs)
    else:
        trainer.create_model_card(**kwargs)


if __name__ == "__main__":
    main()
hpo-examples/audio-classification/trplib.py
ADDED
@@ -0,0 +1,1181 @@
1 |
+
import torch
|
2 |
+
from torch import nn, Tensor
|
3 |
+
from torch.nn import functional as F
|
4 |
+
|
5 |
+
from torchvision.models.mobilenetv2 import MobileNetV2
|
6 |
+
from torchvision.models.resnet import ResNet
|
7 |
+
from torchvision.models.efficientnet import EfficientNet
|
8 |
+
from torchvision.models.vision_transformer import VisionTransformer
|
9 |
+
from torchvision.models.segmentation.fcn import FCN
|
10 |
+
from torchvision.models.segmentation.deeplabv3 import DeepLabV3
|
11 |
+
|
12 |
+
import transformers
|
13 |
+
from transformers.modeling_outputs import SequenceClassifierOutput, QuestionAnsweringModelOutput, CausalLMOutput, Seq2SeqLMOutput
|
14 |
+
|
15 |
+
from typing import Optional, Tuple, List, Union, Callable
|
16 |
+
from collections import OrderedDict
|
17 |
+
import types
|
18 |
+
|
19 |
+
|
20 |
+
def trp_criterion(trp_blocks: nn.ModuleList, shared_head: Callable, criterion: Callable, lambdas: List[float], hidden_states: Tensor, logits: Tensor, targets: Tensor, loss_normalization=False):
|
21 |
+
loss, mask = criterion(logits, targets)
|
22 |
+
if loss_normalization:
|
23 |
+
coeff = loss.detach()
|
24 |
+
|
25 |
+
embeds = [hidden_states]
|
26 |
+
predictions = []
|
27 |
+
for k, c in enumerate(lambdas):
|
28 |
+
embeds.append(trp_blocks[k](embeds[-1]))
|
29 |
+
predictions.append(shared_head(embeds[-1]))
|
30 |
+
replica_loss, mask = criterion(predictions[-1], targets, mask)
|
31 |
+
loss += c * replica_loss
|
32 |
+
|
33 |
+
if loss_normalization:
|
34 |
+
with torch.no_grad():
|
35 |
+
coeff = torch.exp(coeff) / torch.exp(loss.detach())
|
36 |
+
loss = coeff * loss
|
37 |
+
|
38 |
+
return loss
|
39 |
+
|
40 |
+
|
41 |
+
class TPBlock(nn.Module):
|
42 |
+
def __init__(self, depths: int, in_features: int, p: float, dim=-1):
|
43 |
+
super(TPBlock, self).__init__()
|
44 |
+
|
45 |
+
self.dropout = nn.Dropout(p)
|
46 |
+
|
47 |
+
self.cdim = dim
|
48 |
+
|
49 |
+
blocks = []
|
50 |
+
for _ in range(depths):
|
51 |
+
blocks.append(nn.Linear(in_features, in_features))
|
52 |
+
nn.init.constant_(blocks[-1].weight, 0.0)
|
53 |
+
nn.init.constant_(blocks[-1].bias, 0.0)
|
54 |
+
blocks.append(nn.ReLU())
|
55 |
+
self.blocks = nn.Sequential(*blocks)
|
56 |
+
|
57 |
+
def forward(self, x):
|
58 |
+
x = self.dropout(x)
|
59 |
+
if self.cdim == -1:
|
60 |
+
x = x + self.blocks(x)
|
61 |
+
else:
|
62 |
+
x = x + torch.movedim(self.blocks(torch.movedim(x, self.cdim, -1)), -1, self.cdim)
|
63 |
+
return x
|
64 |
+
|
65 |
+
|
66 |
+
class Config:
|
67 |
+
@staticmethod
|
68 |
+
def gen_criterion(*args, **kwargs):
|
69 |
+
def func(input, target, mask=None):
|
70 |
+
"""
|
71 |
+
Args:
|
72 |
+
input (Tensor): Input tensor.
|
73 |
+
target (Tensor): Target labels.
|
74 |
+
|
75 |
+
Returns:
|
76 |
+
loss (Tensor): Scalar tensor representing the loss.
|
77 |
+
mask (Tensor): Boolean mask tensor with the same shape of target.
|
78 |
+
"""
|
79 |
+
pass
|
80 |
+
return func
|
81 |
+
|
82 |
+
@staticmethod
|
83 |
+
def gen_shared_head(*args, **kwargs):
|
84 |
+
def func(hidden_states):
|
85 |
+
"""
|
86 |
+
Args:
|
87 |
+
hidden_states (Tensor): Hidden States tensor.
|
88 |
+
|
89 |
+
Returns:
|
90 |
+
logits (Tensor): Logits tensor.
|
91 |
+
"""
|
92 |
+
pass
|
93 |
+
return func
|
94 |
+
|
95 |
+
@staticmethod
|
96 |
+
def forward(*args, **kwargs):
|
97 |
+
pass
|
98 |
+
|
99 |
+
|
100 |
+
# Wav2Vec2 for Audio Classification
|
101 |
+
class Wav2Vec2ForSequenceClassificationConfig(Config):
|
102 |
+
_HIDDEN_STATES_START_POSITION = 2
|
103 |
+
|
104 |
+
@staticmethod
|
105 |
+
def gen_criterion():
|
106 |
+
def func(input, target, mask=None):
|
107 |
+
"""
|
108 |
+
Args:
|
109 |
+
input (Tensor): Input tensor of shape [B, C].
|
110 |
+
target (Tensor): Target labels of shape [B].
|
111 |
+
|
112 |
+
Returns:
|
113 |
+
loss (Tensor): Scalar tensor representing the loss.
|
114 |
+
mask (Tensor): Boolean mask tensor of shape [B].
|
115 |
+
"""
|
116 |
+
if mask is None:
|
117 |
+
mask = torch.ones_like(target, dtype=torch.float32, device=target.device)
|
118 |
+
|
119 |
+
unmasked_loss = F.cross_entropy(input, target, reduction="none")
|
120 |
+
loss = torch.sum(mask * unmasked_loss) / (torch.sum(mask) + 1e-6)
|
121 |
+
|
122 |
+
with torch.no_grad():
|
123 |
+
mask = mask * torch.eq(torch.argmax(input, dim=1), target).to(input.dtype)
|
124 |
+
|
125 |
+
return loss, mask
|
126 |
+
return func
|
127 |
+
|
128 |
+
@staticmethod
|
129 |
+
def gen_shared_head(self, attention_mask):
|
130 |
+
def func(hidden_states):
|
131 |
+
"""
|
132 |
+
Args:
|
133 |
+
hidden_states (Tensor): Hidden States of shape [B, L, hidden_units].
|
134 |
+
|
135 |
+
Returns:
|
136 |
+
logits (Tensor): Logits tensor of shape [B, C].
|
137 |
+
"""
|
138 |
+
_hidden_states = self.projector(hidden_states)
|
139 |
+
if attention_mask is None:
|
140 |
+
pooled_output = _hidden_states.mean(dim=1)
|
141 |
+
else:
|
142 |
+
padding_mask = self._get_feature_vector_attention_mask(_hidden_states.shape[1], attention_mask)
|
143 |
+
expand_padding_mask = padding_mask.unsqueeze(-1).repeat(1, 1, _hidden_states.shape[2])
|
144 |
+
_hidden_states[~expand_padding_mask] = 0.0
|
145 |
+
pooled_output = _hidden_states.sum(dim=1) / padding_mask.sum(dim=1).view(-1, 1)
|
146 |
+
|
147 |
+
logits = self.classifier(pooled_output)
|
148 |
+
return logits
|
149 |
+
return func
|
150 |
+
|
151 |
+
@staticmethod
|
152 |
+
def gen_forward(lambdas, loss_normalization=False):
|
153 |
+
def func(
|
154 |
+
self,
|
155 |
+
input_values: Optional[torch.Tensor],
|
156 |
+
attention_mask: Optional[torch.Tensor] = None,
|
157 |
+
output_attentions: Optional[bool] = None,
|
158 |
+
output_hidden_states: Optional[bool] = None,
|
159 |
+
return_dict: Optional[bool] = None,
|
160 |
+
labels: Optional[torch.Tensor] = None,
|
161 |
+
) -> Union[Tuple, SequenceClassifierOutput]:
|
162 |
+
r"""
|
163 |
+
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
164 |
+
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
|
165 |
+
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
|
166 |
+
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
167 |
+
"""
|
168 |
+
|
169 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
170 |
+
output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states
|
171 |
+
|
172 |
+
outputs = self.wav2vec2(
|
173 |
+
input_values,
|
174 |
+
attention_mask=attention_mask,
|
175 |
+
output_attentions=output_attentions,
|
176 |
+
output_hidden_states=output_hidden_states,
|
177 |
+
return_dict=return_dict,
|
178 |
+
)
|
179 |
+
|
180 |
+
if self.config.use_weighted_layer_sum:
|
181 |
+
hidden_states = outputs[Wav2Vec2ForSequenceClassificationConfig._HIDDEN_STATES_START_POSITION]
|
182 |
+
hidden_states = torch.stack(hidden_states, dim=1)
|
183 |
+
norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
|
184 |
+
hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
|
185 |
+
else:
|
186 |
+
hidden_states = outputs[0]
|
187 |
+
|
188 |
+
_hidden_states = self.projector(hidden_states)
|
189 |
+
if attention_mask is None:
|
190 |
+
pooled_output = _hidden_states.mean(dim=1)
|
191 |
+
else:
|
192 |
+
padding_mask = self._get_feature_vector_attention_mask(_hidden_states.shape[1], attention_mask)
|
193 |
+
expand_padding_mask = padding_mask.unsqueeze(-1).repeat(1, 1, _hidden_states.shape[2])
|
194 |
+
_hidden_states[~expand_padding_mask] = 0.0
|
195 |
+
pooled_output = _hidden_states.sum(dim=1) / padding_mask.sum(dim=1).view(-1, 1)
|
196 |
+
|
197 |
+
logits = self.classifier(pooled_output)
|
198 |
+
|
199 |
+
loss = None
|
200 |
+
if labels is not None:
|
201 |
+
shared_head = Wav2Vec2ForSequenceClassificationConfig.gen_shared_head(self, attention_mask)
|
202 |
+
criterion = Wav2Vec2ForSequenceClassificationConfig.gen_criterion()
|
203 |
+
loss = trp_criterion(self.trp_blocks, shared_head, criterion, lambdas, hidden_states, logits.view(-1, self.config.num_labels), labels.view(-1), loss_normalization) # NOTE: Apply TRP!
|
204 |
+
|
205 |
+
if not return_dict:
|
206 |
+
output = (logits,) + outputs[Wav2Vec2ForSequenceClassificationConfig._HIDDEN_STATES_START_POSITION:]
|
207 |
+
return ((loss,) + output) if loss is not None else output
|
208 |
+
|
209 |
+
return SequenceClassifierOutput(
|
210 |
+
loss=loss,
|
211 |
+
logits=logits,
|
212 |
+
hidden_states=outputs.hidden_states,
|
213 |
+
attentions=outputs.attentions,
|
214 |
+
)
|
215 |
+
return func
|
216 |
+
|
217 |
+
|
218 |
+
# MobileNetV2 for Image Classification
|
219 |
+
class MobileNetV2Config(Config):
|
220 |
+
@staticmethod
|
221 |
+
def gen_criterion(label_smoothing=0.0, top_k=1):
|
222 |
+
def func(input, target, mask=None):
|
223 |
+
"""
|
224 |
+
Args:
|
225 |
+
input (Tensor): Input tensor of shape [B, C].
|
226 |
+
target (Tensor): Target labels of shape [B] or [B, C].
|
227 |
+
|
228 |
+
Returns:
|
229 |
+
loss (Tensor): Scalar tensor representing the loss.
|
230 |
+
mask (Tensor): Boolean mask tensor of shape [B].
|
231 |
+
"""
|
232 |
+
label = torch.argmax(target, dim=1) if label_smoothing > 0.0 else target
|
233 |
+
|
234 |
+
unmasked_loss = F.cross_entropy(input, label, reduction="none", label_smoothing=label_smoothing)
|
235 |
+
if mask is None:
|
236 |
+
mask = torch.ones_like(unmasked_loss, dtype=torch.float32, device=target.device)
|
237 |
+
loss = torch.sum(mask * unmasked_loss) / (torch.sum(mask) + 1e-6)
|
238 |
+
|
239 |
+
with torch.no_grad():
|
240 |
+
topk_values, topk_indices = torch.topk(input, top_k, dim=-1)
|
241 |
+
mask = mask * torch.eq(topk_indices, label[:, None]).any(dim=-1).to(input.dtype)
|
242 |
+
|
243 |
+
return loss, mask
|
244 |
+
return func
|
245 |
+
|
246 |
+
@staticmethod
|
247 |
+
def gen_shared_head(self):
|
248 |
+
def func(x):
|
249 |
+
"""
|
250 |
+
Args:
|
251 |
+
x (Tensor): Hidden States tensor of shape [B, hidden_units].
|
252 |
+
|
253 |
+
Returns:
|
254 |
+
logits (Tensor): Logits tensor of shape [B, C].
|
255 |
+
"""
|
256 |
+
logits = self.classifier(x)
|
257 |
+
return logits
|
258 |
+
return func
|
259 |
+
|
260 |
+
@staticmethod
|
261 |
+
def gen_forward(lambdas, loss_normalization=True, label_smoothing=0.0, top_k=1):
|
262 |
+
def func(self, images: Tensor, targets=None):
|
263 |
+
x = self.features(images)
|
264 |
+
x = nn.functional.adaptive_avg_pool2d(x, (1, 1))
|
265 |
+
x = torch.flatten(x, 1)
|
266 |
+
logits = self.classifier(x)
|
267 |
+
|
268 |
+
if self.training:
|
269 |
+
torch._assert(targets is not None, "targets should not be none when in training mode")
|
270 |
+
shared_head = MobileNetV2Config.gen_shared_head(self)
|
271 |
+
criterion = MobileNetV2Config.gen_criterion(label_smoothing, top_k)
|
272 |
+
loss = trp_criterion(self.trp_blocks, shared_head, criterion, lambdas, x, logits, targets, loss_normalization)
|
273 |
+
return logits, loss
|
274 |
+
return logits
|
275 |
+
return func
|
276 |
+
|
277 |
+
|
278 |
+
# ResNet for Image Classification
|
279 |
+
class ResNetConfig(MobileNetV2Config):
|
280 |
+
@staticmethod
|
281 |
+
def gen_shared_head(self):
|
282 |
+
def func(x):
|
283 |
+
"""
|
284 |
+
Args:
|
285 |
+
x (Tensor): Hidden States tensor of shape [B, hidden_units].
|
286 |
+
|
287 |
+
Returns:
|
288 |
+
logits (Tensor): Logits tensor of shape [B, C].
|
289 |
+
"""
|
290 |
+
logits = self.fc(x)
|
291 |
+
return logits
|
292 |
+
return func
|
293 |
+
|
294 |
+
@staticmethod
|
295 |
+
def gen_forward(lambdas, loss_normalization=True, label_smoothing=0.0, top_k=1):
|
296 |
+
def func(self, images: Tensor, targets=None):
|
297 |
+
x = self.conv1(images)
|
298 |
+
x = self.bn1(x)
|
299 |
+
x = self.relu(x)
|
300 |
+
x = self.maxpool(x)
|
301 |
+
|
302 |
+
x = self.layer1(x)
|
303 |
+
x = self.layer2(x)
|
304 |
+
x = self.layer3(x)
|
305 |
+
x = self.layer4(x)
|
306 |
+
|
307 |
+
x = self.avgpool(x)
|
308 |
+
x = torch.flatten(x, 1)
|
309 |
+
logits = self.fc(x)
|
310 |
+
|
311 |
+
if self.training:
|
312 |
+
torch._assert(targets is not None, "targets should not be none when in training mode")
|
313 |
+
shared_head = ResNetConfig.gen_shared_head(self)
|
314 |
+
criterion = ResNetConfig.gen_criterion(label_smoothing, top_k)
|
315 |
+
loss = trp_criterion(self.trp_blocks, shared_head, criterion, lambdas, x, logits, targets, loss_normalization)
|
316 |
+
return logits, loss
|
317 |
+
return logits
|
318 |
+
return func
|
319 |
+
|
320 |
+
|
321 |
+
# EfficientNet for Image Classification
|
322 |
+
class EfficientNetConfig(MobileNetV2Config):
|
323 |
+
@staticmethod
|
324 |
+
def gen_shared_head(self):
|
325 |
+
def func(x):
|
326 |
+
"""
|
327 |
+
Args:
|
328 |
+
x (Tensor): Hidden States tensor of shape [B, hidden_units].
|
329 |
+
|
330 |
+
Returns:
|
331 |
+
logits (Tensor): Logits tensor of shape [B, C].
|
332 |
+
"""
|
333 |
+
logits = self.classifier(x)
|
334 |
+
return logits
|
335 |
+
return func
|
336 |
+
|
337 |
+
@staticmethod
|
338 |
+
def gen_forward(lambdas, loss_normalization=True, label_smoothing=0.0, top_k=1):
|
339 |
+
def func(self, images: Tensor, targets=None):
|
340 |
+
x = self.features(images)
|
341 |
+
x = self.avgpool(x)
|
342 |
+
x = torch.flatten(x, 1)
|
343 |
+
logits = self.classifier(x)
|
344 |
+
|
345 |
+
if self.training:
|
346 |
+
torch._assert(targets is not None, "targets should not be none when in training mode")
|
347 |
+
shared_head = EfficientNetConfig.gen_shared_head(self)
|
348 |
+
criterion = EfficientNetConfig.gen_criterion(label_smoothing, top_k)
|
349 |
+
loss = trp_criterion(self.trp_blocks, shared_head, criterion, lambdas, x, logits, targets, loss_normalization)
|
350 |
+
return logits, loss
|
351 |
+
return logits
|
352 |
+
return func
|
353 |
+
|
354 |
+
|
355 |
+
# ViT for Image Classification
|
356 |
+
class VisionTransformerConfig(MobileNetV2Config):
|
357 |
+
@staticmethod
|
358 |
+
def gen_shared_head(self):
|
359 |
+
def func(x):
|
360 |
+
"""
|
361 |
+
Args:
|
362 |
+
x (Tensor): Hidden States tensor of shape [B, hidden_units].
|
363 |
+
|
364 |
+
Returns:
|
365 |
+
logits (Tensor): Logits tensor of shape [B, C].
|
366 |
+
"""
|
367 |
+
logits = self.heads(x)
|
368 |
+
return logits
|
369 |
+
return func
|
370 |
+
|
371 |
+
@staticmethod
|
372 |
+
def gen_forward(lambdas, loss_normalization=True, label_smoothing=0.0, top_k=1):
|
373 |
+
def func(self, images: Tensor, targets=None):
|
374 |
+
x = self._process_input(images)
|
375 |
+
n = x.shape[0]
|
376 |
+
batch_class_token = self.class_token.expand(n, -1, -1)
|
377 |
+
x = torch.cat([batch_class_token, x], dim=1)
|
378 |
+
x = self.encoder(x)
|
379 |
+
x = x[:, 0]
|
380 |
+
|
381 |
+
logits = self.heads(x)
|
382 |
+
|
383 |
+
if self.training:
|
384 |
+
torch._assert(targets is not None, "targets should not be none when in training mode")
|
385 |
+
shared_head = VisionTransformerConfig.gen_shared_head(self)
|
386 |
+
criterion = VisionTransformerConfig.gen_criterion(label_smoothing, top_k)
|
387 |
+
loss = trp_criterion(self.trp_blocks, shared_head, criterion, lambdas, x, logits, targets, loss_normalization)
|
388 |
+
return logits, loss
|
389 |
+
return logits
|
390 |
+
return func
|
391 |
+
|
392 |
+
|
393 |
+
# Bert for Question Answering
|
394 |
+
class BertForQuestionAnsweringConfig(Config):
|
395 |
+
@staticmethod
|
396 |
+
def gen_criterion(top_k=1):
|
397 |
+
def func(input, target: List[Tensor], mask=None):
|
398 |
+
"""
|
399 |
+
Args:
|
400 |
+
input (Tensor): Input tensor of shape [B, C, 2].
|
401 |
+
target (List[Tensor]):
|
402 |
+
Start Positions of shape [B].
|
403 |
+
End Positions of shape [B].
|
404 |
+
|
405 |
+
Returns:
|
406 |
+
loss (Tensor): Scalar tensor representing the loss.
|
407 |
+
mask (Tensor): Boolean mask tensor of shape [B].
|
408 |
+
"""
|
409 |
+
start_positions, end_positions = target
|
410 |
+
|
411 |
+
if mask is None:
|
412 |
+
mask = torch.ones_like(start_positions, dtype=torch.float32, device=start_positions.device)
|
413 |
+
|
414 |
+
start_logits, end_logits = input.split(1, dim=-1)
|
415 |
+
start_logits = start_logits.squeeze(-1).contiguous()
|
416 |
+
end_logits = end_logits.squeeze(-1).contiguous()
|
417 |
+
|
418 |
+
# If we are on multi-GPU, the split may add an extra dimension; squeeze it
|
419 |
+
if len(start_positions.size()) > 1:
|
420 |
+
start_positions = start_positions.squeeze(-1)
|
421 |
+
if len(end_positions.size()) > 1:
|
422 |
+
end_positions = end_positions.squeeze(-1)
|
423 |
+
# Sometimes the start/end positions are outside our model inputs; we ignore these terms
|
424 |
+
ignored_index = start_logits.size(1)
|
425 |
+
start_positions = start_positions.clamp(0, ignored_index)
|
426 |
+
end_positions = end_positions.clamp(0, ignored_index)
|
427 |
+
|
428 |
+
masked_start_losses = F.cross_entropy(start_logits, start_positions, ignore_index=ignored_index, reduction="none")
|
429 |
+
start_loss = torch.sum(mask * masked_start_losses) / (torch.sum(mask) + 1e-6)
|
430 |
+
masked_end_losses = F.cross_entropy(end_logits, end_positions, ignore_index=ignored_index, reduction="none")
|
431 |
+
end_loss = torch.sum(mask * masked_end_losses) / (torch.sum(mask) + 1e-6)
|
432 |
+
|
433 |
+
with torch.no_grad():
|
434 |
+
topk_values, topk_indices = torch.topk(start_logits, top_k, dim=1)
|
435 |
+
mask = mask * torch.eq(topk_indices, start_positions[:, None]).any(dim=1).to(start_logits.dtype)
|
436 |
+
topk_values, topk_indices = torch.topk(end_logits, top_k, dim=1)
|
437 |
+
mask = mask * torch.eq(topk_indices, end_positions[:, None]).any(dim=1).to(end_logits.dtype)
|
438 |
+
|
439 |
+
return (start_loss + end_loss) / 2, mask
|
440 |
+
return func
|
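# Note on the mask above: a sample survives only if its gold start and end positions both
# appear in the head's top-k predictions; trp_criterion (defined earlier in this file)
# presumably uses that mask to restrict the later TRP rounds to samples the current head
# already ranks correctly.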
441 |
+
|
442 |
+
@staticmethod
|
443 |
+
def gen_shared_head(self):
|
444 |
+
def func(hidden_states):
|
445 |
+
"""
|
446 |
+
Args:
|
447 |
+
hidden_states (Tensor): Hidden States of shape [B, C, hidden_units].
|
448 |
+
|
449 |
+
Returns:
|
450 |
+
logits (Tensor): Logits tensor of shape [B, C, 2].
|
451 |
+
"""
|
452 |
+
logits = self.qa_outputs(hidden_states)
|
453 |
+
return logits
|
454 |
+
return func
|
455 |
+
|
456 |
+
@staticmethod
|
457 |
+
def gen_forward(lambdas, loss_normalization=True, top_k=1):
|
458 |
+
def func(
|
459 |
+
self,
|
460 |
+
input_ids: Optional[torch.Tensor] = None,
|
461 |
+
attention_mask: Optional[torch.Tensor] = None,
|
462 |
+
token_type_ids: Optional[torch.Tensor] = None,
|
463 |
+
position_ids: Optional[torch.Tensor] = None,
|
464 |
+
head_mask: Optional[torch.Tensor] = None,
|
465 |
+
inputs_embeds: Optional[torch.Tensor] = None,
|
466 |
+
start_positions: Optional[torch.Tensor] = None,
|
467 |
+
end_positions: Optional[torch.Tensor] = None,
|
468 |
+
output_attentions: Optional[bool] = None,
|
469 |
+
output_hidden_states: Optional[bool] = None,
|
470 |
+
return_dict: Optional[bool] = None,
|
471 |
+
) -> Union[Tuple[torch.Tensor], QuestionAnsweringModelOutput]:
|
472 |
+
r"""
|
473 |
+
start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
474 |
+
Labels for position (index) of the start of the labelled span for computing the token classification loss.
|
475 |
+
Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
|
476 |
+
are not taken into account for computing the loss.
|
477 |
+
end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
478 |
+
Labels for position (index) of the end of the labelled span for computing the token classification loss.
|
479 |
+
Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
|
480 |
+
are not taken into account for computing the loss.
|
481 |
+
"""
|
482 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
483 |
+
|
484 |
+
outputs = self.bert(
|
485 |
+
input_ids,
|
486 |
+
attention_mask=attention_mask,
|
487 |
+
token_type_ids=token_type_ids,
|
488 |
+
position_ids=position_ids,
|
489 |
+
head_mask=head_mask,
|
490 |
+
inputs_embeds=inputs_embeds,
|
491 |
+
output_attentions=output_attentions,
|
492 |
+
output_hidden_states=output_hidden_states,
|
493 |
+
return_dict=return_dict,
|
494 |
+
)
|
495 |
+
|
496 |
+
sequence_output = outputs[0]
|
497 |
+
|
498 |
+
logits = self.qa_outputs(sequence_output)
|
499 |
+
start_logits, end_logits = logits.split(1, dim=-1)
|
500 |
+
start_logits = start_logits.squeeze(-1).contiguous()
|
501 |
+
end_logits = end_logits.squeeze(-1).contiguous()
|
502 |
+
|
503 |
+
total_loss = None
|
504 |
+
if start_positions is not None and end_positions is not None:
|
505 |
+
shared_head = BertForQuestionAnsweringConfig.gen_shared_head(self)
|
506 |
+
criterion = BertForQuestionAnsweringConfig.gen_criterion()
|
507 |
+
total_loss = trp_criterion(self.trp_blocks, shared_head, criterion, lambdas, sequence_output, logits, [start_positions, end_positions], loss_normalization) # NOTE: Apply TRP!
|
508 |
+
|
509 |
+
if not return_dict:
|
510 |
+
output = (start_logits, end_logits) + outputs[2:]
|
511 |
+
return ((total_loss,) + output) if total_loss is not None else output
|
512 |
+
|
513 |
+
return QuestionAnsweringModelOutput(
|
514 |
+
loss=total_loss,
|
515 |
+
start_logits=start_logits,
|
516 |
+
end_logits=end_logits,
|
517 |
+
hidden_states=outputs.hidden_states,
|
518 |
+
attentions=outputs.attentions,
|
519 |
+
)
|
520 |
+
return func
|
521 |
+
|
522 |
+
|
523 |
+
# FCN for Semantic Segmentation
|
524 |
+
class FCNConfig(Config):
|
525 |
+
@staticmethod
|
526 |
+
def gen_criterion(top_k=1):
|
527 |
+
def func(input, target, mask=None):
|
528 |
+
"""
|
529 |
+
Args:
|
530 |
+
input (Tensor): Input tensor of shape [B, C, H, W].
|
531 |
+
target (Tensor): Target labels of shape [B, H, W].
|
532 |
+
|
533 |
+
Returns:
|
534 |
+
loss (Tensor): Scalar tensor representing the loss.
|
535 |
+
mask (Tensor): Boolean mask tensor of shape [B, H, W].
|
536 |
+
"""
|
537 |
+
if mask is None:
|
538 |
+
mask = torch.ones_like(target, dtype=torch.float32, device=target.device)
|
539 |
+
|
540 |
+
masked_loss = F.cross_entropy(input, target, ignore_index=255, reduction="none")
|
541 |
+
loss = torch.sum(mask * masked_loss) / (torch.sum(mask) + 1e-6)
|
542 |
+
|
543 |
+
with torch.no_grad():
|
544 |
+
topk_values, topk_indices = torch.topk(input, top_k, dim=1)
|
545 |
+
mask = mask * torch.eq(topk_indices, target[:, None, :, :]).any(dim=1).to(input.dtype)
|
546 |
+
# mask = mask * torch.eq(torch.argmax(x, dim=1), target).to(x.dtype)
|
547 |
+
|
548 |
+
return loss, mask
|
549 |
+
return func
|
550 |
+
|
551 |
+
@staticmethod
|
552 |
+
def gen_out_shared_head(self, input_shape):
|
553 |
+
def func(features):
|
554 |
+
"""
|
555 |
+
Args:
|
556 |
+
features (Tensor): features tensor of shape [B, hidden_units, H, W].
|
557 |
+
|
558 |
+
Returns:
|
559 |
+
result (Tensor): Result tensor of shape [B, C, H, W].
|
560 |
+
"""
|
561 |
+
x = self.classifier(features)
|
562 |
+
result = F.interpolate(x, size=input_shape, mode="bilinear", align_corners=False)
|
563 |
+
return result
|
564 |
+
return func
|
565 |
+
|
566 |
+
@staticmethod
|
567 |
+
def gen_aux_shared_head(self, input_shape):
|
568 |
+
def func(features):
|
569 |
+
"""
|
570 |
+
Args:
|
571 |
+
features (Tensor): features tensor of shape [B, hidden_units, H, W].
|
572 |
+
|
573 |
+
Returns:
|
574 |
+
result (Tensor): Result tensor of shape [B, C, H, W].
|
575 |
+
"""
|
576 |
+
x = self.aux_classifier(features)
|
577 |
+
result = F.interpolate(x, size=input_shape, mode="bilinear", align_corners=False)
|
578 |
+
return result
|
579 |
+
return func
|
580 |
+
|
581 |
+
@staticmethod
|
582 |
+
def gen_forward(lambdas, loss_normalization=True, top_k=1):
|
583 |
+
def func(self, images: Tensor, targets=None):
|
584 |
+
input_shape = images.shape[-2:]
|
585 |
+
# contract: features is a dict of tensors
|
586 |
+
features = self.backbone(images)
|
587 |
+
|
588 |
+
result = OrderedDict()
|
589 |
+
x = features["out"]
|
590 |
+
x = self.classifier(x)
|
591 |
+
x = F.interpolate(x, size=input_shape, mode="bilinear", align_corners=False)
|
592 |
+
result["out"] = x
|
593 |
+
|
594 |
+
if self.aux_classifier is not None:
|
595 |
+
x = features["aux"]
|
596 |
+
x = self.aux_classifier(x)
|
597 |
+
x = F.interpolate(x, size=input_shape, mode="bilinear", align_corners=False)
|
598 |
+
result["aux"] = x
|
599 |
+
|
600 |
+
if self.training:
|
601 |
+
torch._assert(targets is not None, "targets should not be none when in training mode")
|
602 |
+
out_shared_head = FCNConfig.gen_out_shared_head(self, input_shape)
|
603 |
+
aux_shared_head = FCNConfig.gen_aux_shared_head(self, input_shape)
|
604 |
+
criterion = FCNConfig.gen_criterion(top_k)
|
605 |
+
out_loss = trp_criterion(self.out_trp_blocks, out_shared_head, criterion, lambdas, features["out"], result["out"], targets, loss_normalization)
|
606 |
+
aux_loss = trp_criterion(self.aux_trp_blocks, aux_shared_head, criterion, lambdas, features["aux"], result["aux"], targets, loss_normalization)
|
607 |
+
loss = out_loss + 0.5 * aux_loss
|
608 |
+
return result, loss
|
609 |
+
return result
|
610 |
+
return func
|
611 |
+
|
612 |
+
|
613 |
+
# DeepLabV3Config for Semantic Segmentation
|
614 |
+
class DeepLabV3Config(FCNConfig):
|
615 |
+
pass
|
616 |
+
|
617 |
+
|
618 |
+
# Bert for Text Classification
|
619 |
+
class BertForSequenceClassificationConfig(Config):
|
620 |
+
@staticmethod
|
621 |
+
def gen_criterion():
|
622 |
+
def func(input, target, mask=None):
|
623 |
+
"""
|
624 |
+
Args:
|
625 |
+
input (Tensor): Input tensor of shape [B, C].
|
626 |
+
target (Tensor): Target labels of shape [B].
|
627 |
+
|
628 |
+
Returns:
|
629 |
+
loss (Tensor): Scalar tensor representing the loss.
|
630 |
+
mask (Tensor): Boolean mask tensor of shape [B].
|
631 |
+
"""
|
632 |
+
if mask is None:
|
633 |
+
mask = torch.ones_like(target, dtype=torch.float32, device=target.device)
|
634 |
+
|
635 |
+
unmasked_loss = F.cross_entropy(input, target, reduction="none")
|
636 |
+
loss = torch.sum(mask * unmasked_loss) / (torch.sum(mask) + 1e-6)
|
637 |
+
|
638 |
+
with torch.no_grad():
|
639 |
+
mask = mask * torch.eq(torch.argmax(input, dim=1), target).to(input.dtype)
|
640 |
+
|
641 |
+
return loss, mask
|
642 |
+
return func
|
643 |
+
|
644 |
+
@staticmethod
|
645 |
+
def gen_shared_head(self):
|
646 |
+
def func(hidden_states):
|
647 |
+
"""
|
648 |
+
Args:
|
649 |
+
hidden_states (Tensor): Hidden States of shape [B, hidden_units].
|
650 |
+
|
651 |
+
Returns:
|
652 |
+
logits (Tensor): Logits tensor of shape [B, C].
|
653 |
+
"""
|
654 |
+
logits = self.classifier(hidden_states)
|
655 |
+
return logits
|
656 |
+
return func
|
657 |
+
|
658 |
+
@staticmethod
|
659 |
+
def gen_forward(lambdas, loss_normalization=False):
|
660 |
+
def func(
|
661 |
+
self,
|
662 |
+
input_ids: Optional[torch.Tensor] = None,
|
663 |
+
attention_mask: Optional[torch.Tensor] = None,
|
664 |
+
token_type_ids: Optional[torch.Tensor] = None,
|
665 |
+
position_ids: Optional[torch.Tensor] = None,
|
666 |
+
head_mask: Optional[torch.Tensor] = None,
|
667 |
+
inputs_embeds: Optional[torch.Tensor] = None,
|
668 |
+
labels: Optional[torch.Tensor] = None,
|
669 |
+
output_attentions: Optional[bool] = None,
|
670 |
+
output_hidden_states: Optional[bool] = None,
|
671 |
+
return_dict: Optional[bool] = None,
|
672 |
+
) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
|
673 |
+
r"""
|
674 |
+
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
675 |
+
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
|
676 |
+
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
|
677 |
+
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
678 |
+
"""
|
679 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
680 |
+
|
681 |
+
outputs = self.bert(
|
682 |
+
input_ids,
|
683 |
+
attention_mask=attention_mask,
|
684 |
+
token_type_ids=token_type_ids,
|
685 |
+
position_ids=position_ids,
|
686 |
+
head_mask=head_mask,
|
687 |
+
inputs_embeds=inputs_embeds,
|
688 |
+
output_attentions=output_attentions,
|
689 |
+
output_hidden_states=output_hidden_states,
|
690 |
+
return_dict=return_dict,
|
691 |
+
)
|
692 |
+
|
693 |
+
pooled_output = outputs[1]
|
694 |
+
|
695 |
+
pooled_output = self.dropout(pooled_output)
|
696 |
+
logits = self.classifier(pooled_output)
|
697 |
+
|
698 |
+
loss = None
|
699 |
+
if labels is not None:
|
700 |
+
assert self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int) # TODO: remove this
|
701 |
+
if self.config.problem_type is None:
|
702 |
+
if self.num_labels == 1:
|
703 |
+
self.config.problem_type = "regression"
|
704 |
+
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
|
705 |
+
self.config.problem_type = "single_label_classification"
|
706 |
+
else:
|
707 |
+
self.config.problem_type = "multi_label_classification"
|
708 |
+
|
709 |
+
if self.config.problem_type == "regression":
|
710 |
+
if self.num_labels == 1:
|
711 |
+
loss = F.mse_loss(logits.squeeze(), labels.squeeze())
|
712 |
+
else:
|
713 |
+
loss = F.mse_loss(logits, labels)
|
714 |
+
elif self.config.problem_type == "single_label_classification":
|
715 |
+
shared_head = BertForSequenceClassificationConfig.gen_shared_head(self)
|
716 |
+
criterion = BertForSequenceClassificationConfig.gen_criterion()
|
717 |
+
loss = trp_criterion(self.trp_blocks, shared_head, criterion, lambdas, pooled_output, logits, labels, loss_normalization)
|
718 |
+
elif self.config.problem_type == "multi_label_classification":
|
719 |
+
loss = F.binary_cross_entropy_with_logits(logits, labels)
|
720 |
+
if not return_dict:
|
721 |
+
output = (logits,) + outputs[2:]
|
722 |
+
return ((loss,) + output) if loss is not None else output
|
723 |
+
|
724 |
+
return SequenceClassifierOutput(
|
725 |
+
loss=loss,
|
726 |
+
logits=logits,
|
727 |
+
hidden_states=outputs.hidden_states,
|
728 |
+
attentions=outputs.attentions,
|
729 |
+
)
|
730 |
+
return func
|
731 |
+
|
732 |
+
|
733 |
+
# Roberta for Text Classification
|
734 |
+
class RobertaForSequenceClassificationConfig(BertForSequenceClassificationConfig):
|
735 |
+
@staticmethod
|
736 |
+
def gen_shared_head(self):
|
737 |
+
def func(hidden_states):
|
738 |
+
"""
|
739 |
+
Args:
|
740 |
+
hidden_states (Tensor): Hidden States of shape [B, hidden_units].
|
741 |
+
|
742 |
+
Returns:
|
743 |
+
logits (Tensor): Logits tensor of shape [B, C].
|
744 |
+
"""
|
745 |
+
logits = self.classifier(hidden_states)
|
746 |
+
return logits
|
747 |
+
return func
|
748 |
+
|
749 |
+
@staticmethod
|
750 |
+
def gen_forward(lambdas, loss_normalization=False):
|
751 |
+
def func(
|
752 |
+
self,
|
753 |
+
input_ids: Optional[torch.LongTensor] = None,
|
754 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
755 |
+
token_type_ids: Optional[torch.LongTensor] = None,
|
756 |
+
position_ids: Optional[torch.LongTensor] = None,
|
757 |
+
head_mask: Optional[torch.FloatTensor] = None,
|
758 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
759 |
+
labels: Optional[torch.LongTensor] = None,
|
760 |
+
output_attentions: Optional[bool] = None,
|
761 |
+
output_hidden_states: Optional[bool] = None,
|
762 |
+
return_dict: Optional[bool] = None,
|
763 |
+
) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
|
764 |
+
r"""
|
765 |
+
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
766 |
+
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
|
767 |
+
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
|
768 |
+
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
769 |
+
"""
|
770 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
771 |
+
|
772 |
+
outputs = self.roberta(
|
773 |
+
input_ids,
|
774 |
+
attention_mask=attention_mask,
|
775 |
+
token_type_ids=token_type_ids,
|
776 |
+
position_ids=position_ids,
|
777 |
+
head_mask=head_mask,
|
778 |
+
inputs_embeds=inputs_embeds,
|
779 |
+
output_attentions=output_attentions,
|
780 |
+
output_hidden_states=output_hidden_states,
|
781 |
+
return_dict=return_dict,
|
782 |
+
)
|
783 |
+
sequence_output = outputs[0]
|
784 |
+
logits = self.classifier(sequence_output)
|
785 |
+
|
786 |
+
loss = None
|
787 |
+
if labels is not None:
|
788 |
+
assert self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int) # TODO: remove this
|
789 |
+
# move labels to correct device to enable model parallelism
|
790 |
+
labels = labels.to(logits.device)
|
791 |
+
if self.config.problem_type is None:
|
792 |
+
if self.num_labels == 1:
|
793 |
+
self.config.problem_type = "regression"
|
794 |
+
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
|
795 |
+
self.config.problem_type = "single_label_classification"
|
796 |
+
else:
|
797 |
+
self.config.problem_type = "multi_label_classification"
|
798 |
+
|
799 |
+
if self.config.problem_type == "regression":
|
800 |
+
if self.num_labels == 1:
|
801 |
+
loss = F.mse_loss(logits.squeeze(), labels.squeeze())
|
802 |
+
else:
|
803 |
+
loss = F.mse_loss(logits, labels)
|
804 |
+
elif self.config.problem_type == "single_label_classification":
|
805 |
+
shared_head = BertForSequenceClassificationConfig.gen_shared_head(self)
|
806 |
+
criterion = BertForSequenceClassificationConfig.gen_criterion()
|
807 |
+
loss = trp_criterion(self.trp_blocks, shared_head, criterion, lambdas, sequence_output, logits, labels, loss_normalization)
|
808 |
+
elif self.config.problem_type == "multi_label_classification":
|
809 |
+
loss = F.binary_cross_entropy_with_logits(logits, labels)
|
810 |
+
|
811 |
+
if not return_dict:
|
812 |
+
output = (logits,) + outputs[2:]
|
813 |
+
return ((loss,) + output) if loss is not None else output
|
814 |
+
|
815 |
+
return SequenceClassifierOutput(
|
816 |
+
loss=loss,
|
817 |
+
logits=logits,
|
818 |
+
hidden_states=outputs.hidden_states,
|
819 |
+
attentions=outputs.attentions,
|
820 |
+
)
|
821 |
+
return func
|
822 |
+
|
823 |
+
|
824 |
+
# Wav2Vec2 for Speech Recognition
|
825 |
+
class Wav2Vec2ForCTCConfig(Config):
|
826 |
+
_HIDDEN_STATES_START_POSITION = 2
|
827 |
+
|
828 |
+
@staticmethod
|
829 |
+
def greedy_decode_ctc(
|
830 |
+
log_probs: torch.Tensor,
|
831 |
+
input_lengths: torch.Tensor,
|
832 |
+
blank_token_id: int,
|
833 |
+
target_lengths: torch.Tensor
|
834 |
+
):
|
835 |
+
"""
|
836 |
+
Convert logits to flattened predictions that match the shape of flattened_targets.
|
837 |
+
|
838 |
+
Args:
|
839 |
+
log_probs: [B, L, V] - log-softmax output
|
840 |
+
input_lengths: [B] - actual length of each input
|
841 |
+
blank_token_id: int - index of blank token
|
842 |
+
target_lengths: [B] - used to determine how many predictions to keep per sample
|
843 |
+
|
844 |
+
Returns:
|
845 |
+
flattened_predictions: 1D tensor, same total length as sum(target_lengths)
|
846 |
+
"""
|
847 |
+
batch_size = log_probs.size(0)
|
848 |
+
decoded_all = []
|
849 |
+
|
850 |
+
predicted_ids = log_probs.argmax(dim=-1) # [B, L]
|
851 |
+
|
852 |
+
for i in range(batch_size):
|
853 |
+
pred = predicted_ids[i][:input_lengths[i]] # [Li]
|
854 |
+
prev = None
|
855 |
+
decoded = []
|
856 |
+
for token in pred:
|
857 |
+
token = token.item()
|
858 |
+
if token != blank_token_id and token != prev:
|
859 |
+
decoded.append(token)
|
860 |
+
prev = token
|
861 |
+
# Trim or pad to match target_lengths[i]
|
862 |
+
tgt_len = target_lengths[i].item()
|
863 |
+
if len(decoded) >= tgt_len:
|
864 |
+
decoded = decoded[:tgt_len]
|
865 |
+
else:
|
866 |
+
decoded = decoded + [blank_token_id] * (tgt_len - len(decoded)) # pad with blank
|
867 |
+
decoded_all.extend(decoded)
|
868 |
+
|
869 |
+
return torch.tensor(decoded_all, dtype=torch.long, device=log_probs.device) # shape: [sum(target_lengths)]
|
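# Illustrative example (values assumed, not from the original code): with blank_token_id=0
# and target_lengths[i]=3, a per-frame argmax of [7, 7, 0, 7, 3, 3, 0, 0] collapses repeats
# and drops blanks to [7, 7, 3]; shorter decodes are padded with the blank id so the
# flattened output lines up element-wise with flattened_targets.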
870 |
+
|
871 |
+
@staticmethod
|
872 |
+
def gen_criterion(input_lengths: Tensor, pad_token_id: int, ctc_zero_infinity: bool):
|
873 |
+
def func(logits: Tensor, labels: Tensor, mask=None):
|
874 |
+
"""
|
875 |
+
Args:
|
876 |
+
logits (Tensor): Logits of shape [B, L, V]; log-softmax is applied internally before the CTC loss.
|
877 |
+
labels (Tensor): Flattened Targets of shape [B, L'].
|
878 |
+
|
879 |
+
Returns:
|
880 |
+
loss (Tensor): Scalar tensor representing the loss.
|
881 |
+
mask (Tensor): Boolean mask tensor of shape [B].
|
882 |
+
"""
|
883 |
+
if mask is None:
|
884 |
+
mask = torch.ones_like(input_lengths, dtype=torch.float32, device=input_lengths.device)
|
885 |
+
|
886 |
+
log_probs = nn.functional.log_softmax(logits, dim=-1, dtype=torch.float32).transpose(0, 1)
|
887 |
+
labels_mask = labels >= 0
|
888 |
+
target_lengths = labels_mask.sum(-1)
|
889 |
+
flattened_targets = labels.masked_select(labels_mask)
|
890 |
+
with torch.backends.cudnn.flags(enabled=False):
|
891 |
+
masked_losses = nn.functional.ctc_loss(log_probs, flattened_targets, input_lengths, target_lengths, blank=pad_token_id, reduction="none", zero_infinity=ctc_zero_infinity)
|
892 |
+
loss = torch.sum(mask * masked_losses) / (torch.sum(mask) + 1e-6)
|
893 |
+
|
894 |
+
with torch.no_grad():
|
895 |
+
thres = 0.5
|
896 |
+
flattened_predictions = Wav2Vec2ForCTCConfig.greedy_decode_ctc(
|
897 |
+
log_probs.transpose(0, 1), # [B, T, V]
|
898 |
+
input_lengths=input_lengths,
|
899 |
+
blank_token_id=pad_token_id,
|
900 |
+
target_lengths=target_lengths
|
901 |
+
)
|
902 |
+
token_wise_mask = torch.eq(flattened_predictions, flattened_targets).to(flattened_targets.dtype)
|
903 |
+
segment_ids = torch.arange(len(target_lengths), device=target_lengths.device).repeat_interleave(target_lengths)
|
904 |
+
sequence_wise_mask = torch.zeros(len(target_lengths), dtype=target_lengths.dtype, device=token_wise_mask.device).scatter_add(0, segment_ids, token_wise_mask)
|
905 |
+
mask = mask * torch.ge(sequence_wise_mask, thres * target_lengths).to(flattened_targets.dtype)
|
906 |
+
|
907 |
+
return loss, mask
|
908 |
+
return func
|
909 |
+
|
910 |
+
@staticmethod
|
911 |
+
def gen_shared_head(self):
|
912 |
+
def func(hidden_states):
|
913 |
+
"""
|
914 |
+
Args:
|
915 |
+
hidden_states (Tensor): Hidden States of shape [B, L, hidden_units].
|
916 |
+
|
917 |
+
Returns:
|
918 |
+
logits (Tensor): Logits tensor of shape [B, L, V].
|
919 |
+
"""
|
920 |
+
logits = self.lm_head(hidden_states)
|
921 |
+
# log_probs = nn.functional.log_softmax(logits, dim=-1, dtype=torch.float32).transpose(0, 1)
|
922 |
+
return logits
|
923 |
+
return func
|
924 |
+
|
925 |
+
@staticmethod
|
926 |
+
def gen_forward(lambdas, loss_normalization=False):
|
927 |
+
def func(
|
928 |
+
self,
|
929 |
+
input_values: Optional[torch.Tensor],
|
930 |
+
attention_mask: Optional[torch.Tensor] = None,
|
931 |
+
output_attentions: Optional[bool] = None,
|
932 |
+
output_hidden_states: Optional[bool] = None,
|
933 |
+
return_dict: Optional[bool] = None,
|
934 |
+
labels: Optional[torch.Tensor] = None,
|
935 |
+
) -> Union[Tuple, CausalLMOutput]:
|
936 |
+
r"""
|
937 |
+
labels (`torch.LongTensor` of shape `(batch_size, target_length)`, *optional*):
|
938 |
+
Labels for connectionist temporal classification. Note that `target_length` has to be smaller or equal to
|
939 |
+
the sequence length of the output logits. Indices are selected in `[-100, 0, ..., config.vocab_size - 1]`.
|
940 |
+
All labels set to `-100` are ignored (masked), the loss is only computed for labels in `[0, ...,
|
941 |
+
config.vocab_size - 1]`.
|
942 |
+
"""
|
943 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
944 |
+
|
945 |
+
if labels is not None and labels.max() >= self.config.vocab_size:
|
946 |
+
raise ValueError(f"Label values must be <= vocab_size: {self.config.vocab_size}")
|
947 |
+
|
948 |
+
outputs = self.wav2vec2(
|
949 |
+
input_values,
|
950 |
+
attention_mask=attention_mask,
|
951 |
+
output_attentions=output_attentions,
|
952 |
+
output_hidden_states=output_hidden_states,
|
953 |
+
return_dict=return_dict,
|
954 |
+
)
|
955 |
+
|
956 |
+
hidden_states = outputs[0]
|
957 |
+
hidden_states = self.dropout(hidden_states)
|
958 |
+
|
959 |
+
logits = self.lm_head(hidden_states)
|
960 |
+
|
961 |
+
loss = None
|
962 |
+
if labels is not None:
|
963 |
+
# retrieve loss input_lengths from attention_mask
|
964 |
+
attention_mask = (
|
965 |
+
attention_mask if attention_mask is not None else torch.ones_like(input_values, dtype=torch.long)
|
966 |
+
)
|
967 |
+
input_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long)
|
968 |
+
shared_head = Wav2Vec2ForCTCConfig.gen_shared_head(self)
|
969 |
+
criterion = Wav2Vec2ForCTCConfig.gen_criterion(input_lengths, self.config.pad_token_id, self.config.ctc_zero_infinity)
|
970 |
+
loss = trp_criterion(self.trp_blocks, shared_head, criterion, lambdas, hidden_states, logits, labels, loss_normalization) # NOTE: Apply TRP!
|
971 |
+
|
972 |
+
if not return_dict:
|
973 |
+
output = (logits,) + outputs[Wav2Vec2ForCTCConfig._HIDDEN_STATES_START_POSITION:]
|
974 |
+
return ((loss,) + output) if loss is not None else output
|
975 |
+
|
976 |
+
return CausalLMOutput(
|
977 |
+
loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions
|
978 |
+
)
|
979 |
+
return func
|
980 |
+
|
981 |
+
|
982 |
+
# MBart for Translation
|
983 |
+
class MBartForConditionalGenerationConfig(Config):
|
984 |
+
@staticmethod
|
985 |
+
def gen_criterion(vocab_size: int, top_k=1):
|
986 |
+
def func(logits, labels, mask=None):
|
987 |
+
"""
|
988 |
+
Args:
|
989 |
+
logits (Tensor): Logits tensor of shape [B, L, V].
|
990 |
+
labels (Tensor): Target labels of shape [B, L].
|
991 |
+
|
992 |
+
Returns:
|
993 |
+
loss (Tensor): Scalar tensor representing the loss.
|
994 |
+
mask (Tensor): Boolean mask tensor of shape [B].
|
995 |
+
"""
|
996 |
+
if mask is None:
|
997 |
+
mask = torch.ones_like(labels.view(-1), dtype=torch.float32, device=labels.device)
|
998 |
+
|
999 |
+
masked_losses = F.cross_entropy(logits.view(-1, vocab_size), labels.view(-1), reduction="none")
|
1000 |
+
loss = torch.sum(mask * masked_losses) / (torch.sum(mask) + 1e-6)
|
1001 |
+
|
1002 |
+
with torch.no_grad():
|
1003 |
+
topk_values, topk_indices = torch.topk(logits.view(-1, vocab_size), top_k, dim=1)
|
1004 |
+
mask = mask * torch.eq(topk_indices, labels.view(-1, 1)).any(dim=1).to(logits.dtype)
|
1005 |
+
|
1006 |
+
return loss, mask
|
1007 |
+
return func
|
1008 |
+
|
1009 |
+
@staticmethod
|
1010 |
+
def gen_shared_head(self):
|
1011 |
+
def func(hidden_states):
|
1012 |
+
"""
|
1013 |
+
Args:
|
1014 |
+
hidden_states (Tensor): Hidden States of shape [B, L, hidden_units].
|
1015 |
+
|
1016 |
+
Returns:
|
1017 |
+
logits (Tensor): Logits tensor of shape [B, L, V].
|
1018 |
+
"""
|
1019 |
+
logits = self.lm_head(hidden_states) + self.final_logits_bias
|
1020 |
+
return logits
|
1021 |
+
return func
|
1022 |
+
|
1023 |
+
@staticmethod
|
1024 |
+
def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int):
|
1025 |
+
"""
|
1026 |
+
Shift input ids one token to the right, and wrap the last non-pad token (the <LID> token). Note that MBart does not
|
1027 |
+
have a single `decoder_start_token_id` in contrast to other Bart-like models.
|
1028 |
+
"""
|
1029 |
+
prev_output_tokens = input_ids.clone()
|
1030 |
+
|
1031 |
+
if pad_token_id is None:
|
1032 |
+
raise ValueError("self.model.config.pad_token_id has to be defined.")
|
1033 |
+
# replace possible -100 values in labels by `pad_token_id`
|
1034 |
+
prev_output_tokens.masked_fill_(prev_output_tokens == -100, pad_token_id)
|
1035 |
+
|
1036 |
+
index_of_eos = (prev_output_tokens.ne(pad_token_id).sum(dim=1) - 1).unsqueeze(-1)
|
1037 |
+
decoder_start_tokens = prev_output_tokens.gather(1, index_of_eos).squeeze()
|
1038 |
+
prev_output_tokens[:, 1:] = prev_output_tokens[:, :-1].clone()
|
1039 |
+
prev_output_tokens[:, 0] = decoder_start_tokens
|
1040 |
+
|
1041 |
+
return prev_output_tokens
|
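# Illustrative example (tokens assumed, not from the original code): an unpadded row
# [a, b, c, </s>, <LID>] becomes the decoder input [<LID>, a, b, c, </s>]; any -100 labels
# are first replaced by pad_token_id before the wrap-around shift.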
1042 |
+
|
1043 |
+
@staticmethod
|
1044 |
+
def gen_forward(lambdas, loss_normalization=False):
|
1045 |
+
def func(
|
1046 |
+
self,
|
1047 |
+
input_ids: torch.LongTensor = None,
|
1048 |
+
attention_mask: Optional[torch.Tensor] = None,
|
1049 |
+
decoder_input_ids: Optional[torch.LongTensor] = None,
|
1050 |
+
decoder_attention_mask: Optional[torch.LongTensor] = None,
|
1051 |
+
head_mask: Optional[torch.Tensor] = None,
|
1052 |
+
decoder_head_mask: Optional[torch.Tensor] = None,
|
1053 |
+
cross_attn_head_mask: Optional[torch.Tensor] = None,
|
1054 |
+
encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
1055 |
+
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
1056 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
1057 |
+
decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
|
1058 |
+
labels: Optional[torch.LongTensor] = None,
|
1059 |
+
use_cache: Optional[bool] = None,
|
1060 |
+
output_attentions: Optional[bool] = None,
|
1061 |
+
output_hidden_states: Optional[bool] = None,
|
1062 |
+
return_dict: Optional[bool] = None,
|
1063 |
+
) -> Union[Seq2SeqLMOutput, Tuple[torch.FloatTensor]]:
|
1064 |
+
r"""
|
1065 |
+
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
1066 |
+
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
1067 |
+
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
1068 |
+
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
1069 |
+
|
1070 |
+
Returns:
|
1071 |
+
|
1072 |
+
"""
|
1073 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
1074 |
+
|
1075 |
+
if labels is not None:
|
1076 |
+
# if use_cache:
|
1077 |
+
# logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.")
|
1078 |
+
use_cache = False
|
1079 |
+
if decoder_input_ids is None and decoder_inputs_embeds is None:
|
1080 |
+
decoder_input_ids = MBartForConditionalGenerationConfig.shift_tokens_right(labels, self.config.pad_token_id)
|
1081 |
+
|
1082 |
+
outputs = self.model(
|
1083 |
+
input_ids,
|
1084 |
+
attention_mask=attention_mask,
|
1085 |
+
decoder_input_ids=decoder_input_ids,
|
1086 |
+
encoder_outputs=encoder_outputs,
|
1087 |
+
decoder_attention_mask=decoder_attention_mask,
|
1088 |
+
head_mask=head_mask,
|
1089 |
+
decoder_head_mask=decoder_head_mask,
|
1090 |
+
cross_attn_head_mask=cross_attn_head_mask,
|
1091 |
+
past_key_values=past_key_values,
|
1092 |
+
inputs_embeds=inputs_embeds,
|
1093 |
+
decoder_inputs_embeds=decoder_inputs_embeds,
|
1094 |
+
use_cache=use_cache,
|
1095 |
+
output_attentions=output_attentions,
|
1096 |
+
output_hidden_states=output_hidden_states,
|
1097 |
+
return_dict=return_dict,
|
1098 |
+
)
|
1099 |
+
lm_logits = self.lm_head(outputs[0]) + self.final_logits_bias
|
1100 |
+
|
1101 |
+
masked_lm_loss = None
|
1102 |
+
if labels is not None:
|
1103 |
+
shared_head = MBartForConditionalGenerationConfig.gen_shared_head(self)
|
1104 |
+
criterion = MBartForConditionalGenerationConfig.gen_criterion(self.config.vocab_size)
|
1105 |
+
masked_lm_loss = trp_criterion(self.trp_blocks, shared_head, criterion, lambdas, outputs[0], lm_logits, labels, loss_normalization)
|
1106 |
+
|
1107 |
+
if not return_dict:
|
1108 |
+
output = (lm_logits,) + outputs[1:]
|
1109 |
+
return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
|
1110 |
+
|
1111 |
+
return Seq2SeqLMOutput(
|
1112 |
+
loss=masked_lm_loss,
|
1113 |
+
logits=lm_logits,
|
1114 |
+
past_key_values=outputs.past_key_values,
|
1115 |
+
decoder_hidden_states=outputs.decoder_hidden_states,
|
1116 |
+
decoder_attentions=outputs.decoder_attentions,
|
1117 |
+
cross_attentions=outputs.cross_attentions,
|
1118 |
+
encoder_last_hidden_state=outputs.encoder_last_hidden_state,
|
1119 |
+
encoder_hidden_states=outputs.encoder_hidden_states,
|
1120 |
+
encoder_attentions=outputs.encoder_attentions,
|
1121 |
+
)
|
1122 |
+
return func
|
1123 |
+
|
1124 |
+
|
1125 |
+
def apply_trp(model, depths: int, p: float, lambdas: List[float], **kwargs):
|
1126 |
+
if isinstance(model, transformers.Wav2Vec2ForSequenceClassification):
|
1127 |
+
print("✅ Applying TRP to Wav2Vec2 for Audio Classification...")
|
1128 |
+
model.trp_blocks = torch.nn.ModuleList([TPBlock(depths, 768, p) for _ in lambdas])
|
1129 |
+
model.forward = types.MethodType(Wav2Vec2ForSequenceClassificationConfig.gen_forward(lambdas, False), model)
|
1130 |
+
elif isinstance(model, MobileNetV2):
|
1131 |
+
print("✅ Applying TRP to MobileNetV2 for Image Classification...")
|
1132 |
+
model.trp_blocks = torch.nn.ModuleList([TPBlock(depths, 1280, p) for _ in lambdas])
|
1133 |
+
model.forward = types.MethodType(MobileNetV2Config.gen_forward(lambdas, True, label_smoothing=0.0, top_k=1), model)
|
1134 |
+
elif isinstance(model, ResNet):
|
1135 |
+
print("✅ Applying TRP to ResNet for Image Classification...")
|
1136 |
+
model.trp_blocks = torch.nn.ModuleList([TPBlock(depths, 2048, p) for _ in lambdas])
|
1137 |
+
model.forward = types.MethodType(ResNetConfig.gen_forward(lambdas, True, label_smoothing=0.0, top_k=1), model)
|
1138 |
+
elif isinstance(model, EfficientNet):
|
1139 |
+
print("✅ Applying TRP to EfficientNet for Image Classification...")
|
1140 |
+
model.trp_blocks = torch.nn.ModuleList([TPBlock(depths, 1280, p) for _ in lambdas])
|
1141 |
+
model.forward = types.MethodType(EfficientNetConfig.gen_forward(lambdas, True, label_smoothing=kwargs["label_smoothing"], top_k=1), model)
|
1142 |
+
elif isinstance(model, VisionTransformer):
|
1143 |
+
print("✅ Applying TRP to VisionTransformer for Image Classification...")
|
1144 |
+
model.trp_blocks = torch.nn.ModuleList([TPBlock(depths, 768, p) for _ in lambdas])
|
1145 |
+
model.forward = types.MethodType(VisionTransformerConfig.gen_forward(lambdas, True, label_smoothing=kwargs["label_smoothing"], top_k=1), model)
|
1146 |
+
elif isinstance(model, transformers.BertForQuestionAnswering):
|
1147 |
+
print("✅ Applying TRP to Bert for Question Answering...")
|
1148 |
+
model.trp_blocks = torch.nn.ModuleList([TPBlock(depths, 768, p) for _ in lambdas])
|
1149 |
+
model.forward = types.MethodType(BertForQuestionAnsweringConfig.gen_forward(lambdas, True, 1), model)
|
1150 |
+
elif isinstance(model, FCN):
|
1151 |
+
print("✅ Applying TRP to FCN for Semantic Segmentation...")
|
1152 |
+
model.out_trp_blocks = torch.nn.ModuleList([TPBlock(depths, 2048, p, dim=1) for _ in lambdas])
|
1153 |
+
model.aux_trp_blocks = torch.nn.ModuleList([TPBlock(depths, 1024, p, dim=1) for _ in lambdas])
|
1154 |
+
model.forward = types.MethodType(FCNConfig.gen_forward(lambdas, True, 1), model)
|
1155 |
+
elif isinstance(model, DeepLabV3):
|
1156 |
+
print("✅ Applying TRP to DeepLabV3 for Semantic Segmentation...")
|
1157 |
+
model.out_trp_blocks = torch.nn.ModuleList([TPBlock(depths, 2048, p, dim=1) for _ in lambdas])
|
1158 |
+
model.aux_trp_blocks = torch.nn.ModuleList([TPBlock(depths, 1024, p, dim=1) for _ in lambdas])
|
1159 |
+
model.forward = types.MethodType(DeepLabV3Config.gen_forward(lambdas, True, 1), model)
|
1160 |
+
elif isinstance(model, transformers.BertForSequenceClassification):
|
1161 |
+
print("✅ Applying TRP to Bert for Text Classification...")
|
1162 |
+
model.trp_blocks = torch.nn.ModuleList([TPBlock(depths, 768, p) for _ in lambdas])
|
1163 |
+
model.forward = types.MethodType(BertForSequenceClassificationConfig.gen_forward(lambdas, False), model)
|
1164 |
+
elif isinstance(model, transformers.RobertaForSequenceClassification):
|
1165 |
+
print("✅ Applying TRP to Roberta for Text Classification...")
|
1166 |
+
model.trp_blocks = torch.nn.ModuleList([TPBlock(depths, 768, p) for _ in lambdas])
|
1167 |
+
model.forward = types.MethodType(RobertaForSequenceClassificationConfig.gen_forward(lambdas, False), model)
|
1168 |
+
elif isinstance(model, transformers.Wav2Vec2ForCTC):
|
1169 |
+
print("✅ Applying TRP to Wav2Vec2 for Speech Recognition...")
|
1170 |
+
model.trp_blocks = torch.nn.ModuleList([TPBlock(depths, 1024, p) for _ in lambdas])
|
1171 |
+
model.forward = types.MethodType(Wav2Vec2ForCTCConfig.gen_forward(lambdas, False), model)
|
1172 |
+
elif isinstance(model, transformers.MBartForConditionalGeneration):
|
1173 |
+
print("✅ Applying TRP to MBart for Translation...")
|
1174 |
+
model.trp_blocks = torch.nn.ModuleList([TPBlock(depths, 1024, p) for _ in lambdas])
|
1175 |
+
model.forward = types.MethodType(MBartForConditionalGenerationConfig.gen_forward(lambdas, False), model)
|
1176 |
+
else:
|
1177 |
+
torch._assert(
|
1178 |
+
isinstance(model, transformers.Wav2Vec2ForSequenceClassification),
|
1179 |
+
"The model should be an object of [`Wav2Vec2ForSequenceClassification`].")
|
1180 |
+
|
1181 |
+
return model
|
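A minimal usage sketch of `apply_trp` (the hyperparameters mirror the ResNet-50 settings in run.sh; the weights enum and random tensors are illustrative, and the patched forward is assumed to return (logits, loss) in training mode, as the image-classification configs above do):
import torch
import torchvision
from trplib import apply_trp

model = torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.IMAGENET1K_V1)
model = apply_trp(model, depths=1, p=0.2, lambdas=[0.4, 0.2, 0.1])  # ResNet branch needs no extra kwargs

model.train()
images = torch.randn(2, 3, 224, 224)
targets = torch.randint(0, 1000, (2,))
logits, loss = model(images, targets)  # patched forward returns (logits, loss) when training
loss.backward()

model.eval()
with torch.no_grad():
    logits = model(images)  # evaluation path returns logits only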
hpo-examples/image-classification/__pycache__/presets.cpython-310.pyc
ADDED
Binary file (2.31 kB).
|
|
hpo-examples/image-classification/__pycache__/sampler.cpython-310.pyc
ADDED
Binary file (2.41 kB).
|
|
hpo-examples/image-classification/__pycache__/transforms.cpython-310.pyc
ADDED
Binary file (5.29 kB).
|
|
hpo-examples/image-classification/__pycache__/trplib.cpython-310.pyc
ADDED
Binary file (37.5 kB).
|
|
hpo-examples/image-classification/__pycache__/utils.cpython-310.pyc
ADDED
Binary file (14.9 kB).
|
|
hpo-examples/image-classification/efficientnet_v2_m/model_7.pth
ADDED
@@ -0,0 +1,3 @@
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:5f4b45a082517a3e60498e92710f5a97b5869515db52e536d9c92a3b68ae4e8f
|
3 |
+
size 454515355
|
hpo-examples/image-classification/mobilenetv2/model_32.pth
ADDED
@@ -0,0 +1,3 @@
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:f240e3b954e1b1786878733c3045125461fa67bcb6ce50a200d5ec6e46081bc7
|
3 |
+
size 48002008
|
hpo-examples/image-classification/presets.py
ADDED
@@ -0,0 +1,71 @@
|
1 |
+
import torch
|
2 |
+
from torchvision.transforms import autoaugment, transforms
|
3 |
+
from torchvision.transforms.functional import InterpolationMode
|
4 |
+
|
5 |
+
|
6 |
+
class ClassificationPresetTrain:
|
7 |
+
def __init__(
|
8 |
+
self,
|
9 |
+
*,
|
10 |
+
crop_size,
|
11 |
+
mean=(0.485, 0.456, 0.406),
|
12 |
+
std=(0.229, 0.224, 0.225),
|
13 |
+
interpolation=InterpolationMode.BILINEAR,
|
14 |
+
hflip_prob=0.5,
|
15 |
+
auto_augment_policy=None,
|
16 |
+
ra_magnitude=9,
|
17 |
+
augmix_severity=3,
|
18 |
+
random_erase_prob=0.0,
|
19 |
+
):
|
20 |
+
trans = [transforms.RandomResizedCrop(crop_size, interpolation=interpolation)]
|
21 |
+
if hflip_prob > 0:
|
22 |
+
trans.append(transforms.RandomHorizontalFlip(hflip_prob))
|
23 |
+
if auto_augment_policy is not None:
|
24 |
+
if auto_augment_policy == "ra":
|
25 |
+
trans.append(autoaugment.RandAugment(interpolation=interpolation, magnitude=ra_magnitude))
|
26 |
+
elif auto_augment_policy == "ta_wide":
|
27 |
+
trans.append(autoaugment.TrivialAugmentWide(interpolation=interpolation))
|
28 |
+
elif auto_augment_policy == "augmix":
|
29 |
+
trans.append(autoaugment.AugMix(interpolation=interpolation, severity=augmix_severity))
|
30 |
+
else:
|
31 |
+
aa_policy = autoaugment.AutoAugmentPolicy(auto_augment_policy)
|
32 |
+
trans.append(autoaugment.AutoAugment(policy=aa_policy, interpolation=interpolation))
|
33 |
+
trans.extend(
|
34 |
+
[
|
35 |
+
transforms.PILToTensor(),
|
36 |
+
transforms.ConvertImageDtype(torch.float),
|
37 |
+
transforms.Normalize(mean=mean, std=std),
|
38 |
+
]
|
39 |
+
)
|
40 |
+
if random_erase_prob > 0:
|
41 |
+
trans.append(transforms.RandomErasing(p=random_erase_prob))
|
42 |
+
|
43 |
+
self.transforms = transforms.Compose(trans)
|
44 |
+
|
45 |
+
def __call__(self, img):
|
46 |
+
return self.transforms(img)
|
47 |
+
|
48 |
+
|
49 |
+
class ClassificationPresetEval:
|
50 |
+
def __init__(
|
51 |
+
self,
|
52 |
+
*,
|
53 |
+
crop_size,
|
54 |
+
resize_size=256,
|
55 |
+
mean=(0.485, 0.456, 0.406),
|
56 |
+
std=(0.229, 0.224, 0.225),
|
57 |
+
interpolation=InterpolationMode.BILINEAR,
|
58 |
+
):
|
59 |
+
|
60 |
+
self.transforms = transforms.Compose(
|
61 |
+
[
|
62 |
+
transforms.Resize(resize_size, interpolation=interpolation),
|
63 |
+
transforms.CenterCrop(crop_size),
|
64 |
+
transforms.PILToTensor(),
|
65 |
+
transforms.ConvertImageDtype(torch.float),
|
66 |
+
transforms.Normalize(mean=mean, std=std),
|
67 |
+
]
|
68 |
+
)
|
69 |
+
|
70 |
+
def __call__(self, img):
|
71 |
+
return self.transforms(img)
|
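For reference, a minimal sketch of how these presets are wired into ImageFolder datasets (paths and sizes are placeholders; train.py below does the same thing from its command-line arguments):
import torchvision
from presets import ClassificationPresetTrain, ClassificationPresetEval

train_tf = ClassificationPresetTrain(crop_size=224, auto_augment_policy="ta_wide", random_erase_prob=0.1)
val_tf = ClassificationPresetEval(crop_size=224, resize_size=256)

train_ds = torchvision.datasets.ImageFolder("/path/to/imagenet/train", train_tf)
val_ds = torchvision.datasets.ImageFolder("/path/to/imagenet/val", val_tf)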
hpo-examples/image-classification/resnet50/model_35.pth
ADDED
@@ -0,0 +1,3 @@
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:3b304e1309e3143b9e7109d5b7263369e4e93c949dde9bab44b3d3b193d16361
|
3 |
+
size 255177167
|
hpo-examples/image-classification/run.sh
ADDED
@@ -0,0 +1,49 @@
|
1 |
+
# ✅ --lr 0.00002 Acc@1 71.878 Acc@5 90.286 -> Acc@1 72.104 Acc@5 90.316 (with normalization)
|
2 |
+
torchrun --nproc_per_node=4 train.py\
|
3 |
+
--data-path /home/cs/Documents/datasets/imagenet\
|
4 |
+
--model mobilenet_v2 --output-dir mobilenet_v2 --weights MobileNet_V2_Weights.IMAGENET1K_V1\
|
5 |
+
--batch-size 192 --epochs 40 --lr 0.0004 --lr-step-size 10 --lr-gamma 0.5 --wd 0.00004 --apply-trp --trp-depths 1 --trp-p 0.15 --trp-lambdas 0.4 0.2 0.1
|
6 |
+
# torchrun --nproc_per_node=4 train.py\
|
7 |
+
# --data-path /home/cs/Documents/datasets/imagenet\
|
8 |
+
# --model mobilenet_v2 --resume mobilenet_v2/model_32.pth --test-only
|
9 |
+
|
10 |
+
|
11 |
+
# ✅ --lr 0.0002 Acc@1 76.130 Acc@5 92.862 -> Acc@1 77.234 Acc@5 93.322 (with normalization)
|
12 |
+
torchrun --nproc_per_node=4 train.py\
|
13 |
+
--data-path /home/cs/Documents/datasets/imagenet\
|
14 |
+
--model resnet50 --output-dir resnet50 --weights ResNet50_Weights.IMAGENET1K_V1\
|
15 |
+
--batch-size 64 --epochs 40 --lr 0.0004 --lr-step-size 10 --lr-gamma 0.5 --print-freq 100\
|
16 |
+
--apply-trp --trp-depths 1 --trp-p 0.2 --trp-lambdas 0.4 0.2 0.1
|
17 |
+
# torchrun --nproc_per_node=4 train.py\
|
18 |
+
# --data-path /home/cs/Documents/datasets/imagenet\
|
19 |
+
# --model resnet50 --resume resnet50/model_35.pth --test-only
|
20 |
+
|
21 |
+
|
22 |
+
# ✅ Test: Acc@1 85.218 Acc@5 97.208
|
23 |
+
torchrun --nproc_per_node=4 train.py \
|
24 |
+
--data-path /home/cs/Documents/datasets/imagenet\
|
25 |
+
--model efficientnet_v2_m --output-dir efficientnet_v2_m --weights EfficientNet_V2_M_Weights.IMAGENET1K_V1\
|
26 |
+
--epochs 10 --batch-size 64 --lr 5e-9 --lr-scheduler cosineannealinglr --weight-decay 0.00002 \
|
27 |
+
--lr-warmup-method constant --lr-warmup-epochs 8 --lr-warmup-decay 0. \
|
28 |
+
--auto-augment ta_wide --random-erase 0.1 --label-smoothing 0.1 --mixup-alpha 0.2 --cutmix-alpha 1.0 --norm-weight-decay 0.0 \
|
29 |
+
--train-crop-size 384 --val-crop-size 480 --val-resize-size 480 --ra-sampler --ra-reps 4 --print-freq 100\
|
30 |
+
--apply-trp --trp-depths 1 --trp-p 0.2 --trp-lambdas 0.4 0.2 0.1
|
31 |
+
# torchrun --nproc_per_node=4 train.py\
|
32 |
+
# --data-path /home/cs/Documents/datasets/imagenet\
|
33 |
+
# --model efficientnet_v2_m --resume efficientnet_v2_m/model_7.pth --test-only\
|
34 |
+
# --val-crop-size 480 --val-resize-size 480
|
35 |
+
|
36 |
+
|
37 |
+
# ✅ Test: Acc@1 81.092 Acc@5 95.304
|
38 |
+
torchrun --nproc_per_node=4 train.py\
|
39 |
+
--data-path /home/cs/Documents/datasets/imagenet\
|
40 |
+
--model vit_b_16 --output-dir vit_b_16 --weights ViT_B_16_Weights.IMAGENET1K_V1\
|
41 |
+
--epochs 5 --batch-size 196 --opt adamw --lr 5e-9 --lr-scheduler cosineannealinglr --wd 0.3\
|
42 |
+
--lr-warmup-method constant --lr-warmup-epochs 3 --lr-warmup-decay 0. \
|
43 |
+
--amp --label-smoothing 0.11 --mixup-alpha 0.2 --auto-augment ra --clip-grad-norm 1 --cutmix-alpha 1.0\
|
44 |
+
--apply-trp --trp-depths 1 --trp-p 0.1 --trp-lambdas 0.4 0.2 0.1 --print-freq 100
|
45 |
+
# torchrun --nproc_per_node=4 train.py\
|
46 |
+
# --data-path /home/cs/Documents/datasets/imagenet\
|
47 |
+
# --model vit_b_16 --resume vit_b_16/model_4.pth --test-only
|
48 |
+
|
49 |
+
|
hpo-examples/image-classification/sampler.py
ADDED
@@ -0,0 +1,62 @@
|
1 |
+
import math
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.distributed as dist
|
5 |
+
|
6 |
+
|
7 |
+
class RASampler(torch.utils.data.Sampler):
|
8 |
+
"""Sampler that restricts data loading to a subset of the dataset for distributed,
|
9 |
+
with repeated augmentation.
|
10 |
+
It ensures that each augmented version of a sample will be visible to a
|
11 |
+
different process (GPU).
|
12 |
+
Heavily based on 'torch.utils.data.DistributedSampler'.
|
13 |
+
|
14 |
+
This is borrowed from the DeiT Repo:
|
15 |
+
https://github.com/facebookresearch/deit/blob/main/samplers.py
|
16 |
+
"""
|
17 |
+
|
18 |
+
def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True, seed=0, repetitions=3):
|
19 |
+
if num_replicas is None:
|
20 |
+
if not dist.is_available():
|
21 |
+
raise RuntimeError("Requires distributed package to be available!")
|
22 |
+
num_replicas = dist.get_world_size()
|
23 |
+
if rank is None:
|
24 |
+
if not dist.is_available():
|
25 |
+
raise RuntimeError("Requires distributed package to be available!")
|
26 |
+
rank = dist.get_rank()
|
27 |
+
self.dataset = dataset
|
28 |
+
self.num_replicas = num_replicas
|
29 |
+
self.rank = rank
|
30 |
+
self.epoch = 0
|
31 |
+
self.num_samples = int(math.ceil(len(self.dataset) * float(repetitions) / self.num_replicas))
|
32 |
+
self.total_size = self.num_samples * self.num_replicas
|
33 |
+
self.num_selected_samples = int(math.floor(len(self.dataset) // 256 * 256 / self.num_replicas))
|
34 |
+
self.shuffle = shuffle
|
35 |
+
self.seed = seed
|
36 |
+
self.repetitions = repetitions
|
37 |
+
|
38 |
+
def __iter__(self):
|
39 |
+
if self.shuffle:
|
40 |
+
# Deterministically shuffle based on epoch
|
41 |
+
g = torch.Generator()
|
42 |
+
g.manual_seed(self.seed + self.epoch)
|
43 |
+
indices = torch.randperm(len(self.dataset), generator=g).tolist()
|
44 |
+
else:
|
45 |
+
indices = list(range(len(self.dataset)))
|
46 |
+
|
47 |
+
# Add extra samples to make it evenly divisible
|
48 |
+
indices = [ele for ele in indices for i in range(self.repetitions)]
|
49 |
+
indices += indices[: (self.total_size - len(indices))]
|
50 |
+
assert len(indices) == self.total_size
|
51 |
+
|
52 |
+
# Subsample
|
53 |
+
indices = indices[self.rank : self.total_size : self.num_replicas]
|
54 |
+
assert len(indices) == self.num_samples
|
55 |
+
|
56 |
+
return iter(indices[: self.num_selected_samples])
|
57 |
+
|
58 |
+
def __len__(self):
|
59 |
+
return self.num_selected_samples
|
60 |
+
|
61 |
+
def set_epoch(self, epoch):
|
62 |
+
self.epoch = epoch
|
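A minimal usage sketch (assumes torch.distributed is already initialized, e.g. by torchrun, and that train_dataset is a placeholder for any map-style dataset; set_epoch must be called every epoch so the shuffle order changes):
import torch
from sampler import RASampler

train_sampler = RASampler(train_dataset, shuffle=True, repetitions=4)  # train_dataset: placeholder
loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, sampler=train_sampler, num_workers=8)

for epoch in range(num_epochs):  # num_epochs: placeholder
    train_sampler.set_epoch(epoch)
    for images, targets in loader:
        pass  # training step goes here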
hpo-examples/image-classification/train.py
ADDED
@@ -0,0 +1,524 @@
|
1 |
+
import datetime
|
2 |
+
import os
|
3 |
+
import time
|
4 |
+
import warnings
|
5 |
+
|
6 |
+
import presets
|
7 |
+
import torch
|
8 |
+
import torch.utils.data
|
9 |
+
import torchvision
|
10 |
+
import transforms
|
11 |
+
import utils
|
12 |
+
from sampler import RASampler
|
13 |
+
from torch import nn
|
14 |
+
from torch.utils.data.dataloader import default_collate
|
15 |
+
from torchvision.transforms.functional import InterpolationMode
|
16 |
+
|
17 |
+
from trplib import apply_trp
|
18 |
+
|
19 |
+
|
20 |
+
def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args, model_ema=None, scaler=None):
|
21 |
+
model.train()
|
22 |
+
metric_logger = utils.MetricLogger(delimiter=" ")
|
23 |
+
metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value}"))
|
24 |
+
metric_logger.add_meter("img/s", utils.SmoothedValue(window_size=10, fmt="{value}"))
|
25 |
+
|
26 |
+
header = f"Epoch: [{epoch}]"
|
27 |
+
for i, (image, target) in enumerate(metric_logger.log_every(data_loader, args.print_freq, header)):
|
28 |
+
start_time = time.time()
|
29 |
+
image, target = image.to(device), target.to(device)
|
30 |
+
with torch.amp.autocast("cuda", enabled=scaler is not None):
|
31 |
+
# output = model(image)
|
32 |
+
# loss = criterion(output, target)
|
33 |
+
output, loss = model(image, target)
|
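# NOTE: with --apply-trp the patched forward computes the TRP loss itself and returns
# (logits, loss), which is why the plain criterion(output, target) call above is commented out.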
34 |
+
|
35 |
+
optimizer.zero_grad()
|
36 |
+
if scaler is not None:
|
37 |
+
scaler.scale(loss).backward()
|
38 |
+
if args.clip_grad_norm is not None:
|
39 |
+
# we should unscale the gradients of optimizer's assigned params if do gradient clipping
|
40 |
+
scaler.unscale_(optimizer)
|
41 |
+
nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad_norm)
|
42 |
+
scaler.step(optimizer)
|
43 |
+
scaler.update()
|
44 |
+
else:
|
45 |
+
loss.backward()
|
46 |
+
if args.clip_grad_norm is not None:
|
47 |
+
nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad_norm)
|
48 |
+
optimizer.step()
|
49 |
+
|
50 |
+
if model_ema and i % args.model_ema_steps == 0:
|
51 |
+
model_ema.update_parameters(model)
|
52 |
+
if epoch < args.lr_warmup_epochs:
|
53 |
+
# Reset ema buffer to keep copying weights during warmup period
|
54 |
+
model_ema.n_averaged.fill_(0)
|
55 |
+
|
56 |
+
acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
|
57 |
+
batch_size = image.shape[0]
|
58 |
+
metric_logger.update(loss=loss.item(), lr=optimizer.param_groups[0]["lr"])
|
59 |
+
metric_logger.meters["acc1"].update(acc1.item(), n=batch_size)
|
60 |
+
metric_logger.meters["acc5"].update(acc5.item(), n=batch_size)
|
61 |
+
metric_logger.meters["img/s"].update(batch_size / (time.time() - start_time))
|
62 |
+
|
63 |
+
|
64 |
+
def evaluate(model, criterion, data_loader, device, print_freq=100, log_suffix=""):
|
65 |
+
model.eval()
|
66 |
+
metric_logger = utils.MetricLogger(delimiter=" ")
|
67 |
+
header = f"Test: {log_suffix}"
|
68 |
+
|
69 |
+
num_processed_samples = 0
|
70 |
+
with torch.inference_mode():
|
71 |
+
for image, target in metric_logger.log_every(data_loader, print_freq, header):
|
72 |
+
image = image.to(device, non_blocking=True)
|
73 |
+
target = target.to(device, non_blocking=True)
|
74 |
+
output = model(image)
|
75 |
+
loss = criterion(output, target)
|
76 |
+
|
77 |
+
acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
|
78 |
+
# FIXME need to take into account that the datasets
|
79 |
+
# could have been padded in distributed setup
|
80 |
+
batch_size = image.shape[0]
|
81 |
+
metric_logger.update(loss=loss.item())
|
82 |
+
metric_logger.meters["acc1"].update(acc1.item(), n=batch_size)
|
83 |
+
metric_logger.meters["acc5"].update(acc5.item(), n=batch_size)
|
84 |
+
num_processed_samples += batch_size
|
85 |
+
    # gather the stats from all processes

    num_processed_samples = utils.reduce_across_processes(num_processed_samples)
    if (
        hasattr(data_loader.dataset, "__len__")
        and len(data_loader.dataset) != num_processed_samples
        and torch.distributed.get_rank() == 0
    ):
        # See FIXME above
        warnings.warn(
            f"It looks like the dataset has {len(data_loader.dataset)} samples, but {num_processed_samples} "
            "samples were used for the validation, which might bias the results. "
            "Try adjusting the batch size and / or the world size. "
            "Setting the world size to 1 is always a safe bet."
        )

    metric_logger.synchronize_between_processes()

    print(f"{header} Acc@1 {metric_logger.acc1.global_avg:.3f} Acc@5 {metric_logger.acc5.global_avg:.3f}")
    return metric_logger.acc1.global_avg


def _get_cache_path(filepath):
    import hashlib

    h = hashlib.sha1(filepath.encode()).hexdigest()
    cache_path = os.path.join("~", ".torch", "vision", "datasets", "imagefolder", h[:10] + ".pt")
    cache_path = os.path.expanduser(cache_path)
    return cache_path


def load_data(traindir, valdir, args):
    # Data loading code
    print("Loading data")
    val_resize_size, val_crop_size, train_crop_size = (
        args.val_resize_size,
        args.val_crop_size,
        args.train_crop_size,
    )
    interpolation = InterpolationMode(args.interpolation)

    print("Loading training data")
    st = time.time()
    cache_path = _get_cache_path(traindir)
    if args.cache_dataset and os.path.exists(cache_path):
        # Attention, as the transforms are also cached!
        print(f"Loading dataset_train from {cache_path}")
        dataset, _ = torch.load(cache_path)
    else:
        auto_augment_policy = getattr(args, "auto_augment", None)
        random_erase_prob = getattr(args, "random_erase", 0.0)
        ra_magnitude = args.ra_magnitude
        augmix_severity = args.augmix_severity
        dataset = torchvision.datasets.ImageFolder(
            traindir,
            presets.ClassificationPresetTrain(
                crop_size=train_crop_size,
                interpolation=interpolation,
                auto_augment_policy=auto_augment_policy,
                random_erase_prob=random_erase_prob,
                ra_magnitude=ra_magnitude,
                augmix_severity=augmix_severity,
            ),
        )
        if args.cache_dataset:
            print(f"Saving dataset_train to {cache_path}")
            utils.mkdir(os.path.dirname(cache_path))
            utils.save_on_master((dataset, traindir), cache_path)
    print("Took", time.time() - st)

    print("Loading validation data")
    cache_path = _get_cache_path(valdir)
    if args.cache_dataset and os.path.exists(cache_path):
        # Attention, as the transforms are also cached!
        print(f"Loading dataset_test from {cache_path}")
        dataset_test, _ = torch.load(cache_path)
    else:
        if args.weights and args.test_only:
            weights = torchvision.models.get_weight(args.weights)
            preprocessing = weights.transforms()
        else:
            preprocessing = presets.ClassificationPresetEval(
                crop_size=val_crop_size, resize_size=val_resize_size, interpolation=interpolation
            )

        dataset_test = torchvision.datasets.ImageFolder(
            valdir,
            preprocessing,
        )
        if args.cache_dataset:
            print(f"Saving dataset_test to {cache_path}")
            utils.mkdir(os.path.dirname(cache_path))
            utils.save_on_master((dataset_test, valdir), cache_path)

    print("Creating data loaders")
    if args.distributed:
        if hasattr(args, "ra_sampler") and args.ra_sampler:
            train_sampler = RASampler(dataset, shuffle=True, repetitions=args.ra_reps)
        else:
            train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
        test_sampler = torch.utils.data.distributed.DistributedSampler(dataset_test, shuffle=False)
    else:
        train_sampler = torch.utils.data.RandomSampler(dataset)
        test_sampler = torch.utils.data.SequentialSampler(dataset_test)

    return dataset, dataset_test, train_sampler, test_sampler


def main(args):
    if args.output_dir:
        utils.mkdir(args.output_dir)

    utils.init_distributed_mode(args)
    print(args)

    device = torch.device(args.device)

    if args.use_deterministic_algorithms:
        torch.backends.cudnn.benchmark = False
        torch.use_deterministic_algorithms(True)
    else:
        torch.backends.cudnn.benchmark = True

    train_dir = os.path.join(args.data_path, "train")
    val_dir = os.path.join(args.data_path, "val")
    dataset, dataset_test, train_sampler, test_sampler = load_data(train_dir, val_dir, args)

    collate_fn = None
    num_classes = len(dataset.classes)
    mixup_transforms = []
    if args.mixup_alpha > 0.0:
        mixup_transforms.append(transforms.RandomMixup(num_classes, p=1.0, alpha=args.mixup_alpha))
    if args.cutmix_alpha > 0.0:
        mixup_transforms.append(transforms.RandomCutmix(num_classes, p=1.0, alpha=args.cutmix_alpha))
    if mixup_transforms:
        mixupcutmix = torchvision.transforms.RandomChoice(mixup_transforms)

        def collate_fn(batch):
            return mixupcutmix(*default_collate(batch))

    data_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=args.batch_size,
        sampler=train_sampler,
        num_workers=args.workers,
        pin_memory=True,
        collate_fn=collate_fn,
    )
    data_loader_test = torch.utils.data.DataLoader(
        dataset_test, batch_size=8, sampler=test_sampler, num_workers=args.workers, pin_memory=True
    )

    print("Creating model")
    model = torchvision.models.get_model(args.model, weights=args.weights, num_classes=num_classes)
    if args.apply_trp:
        model = apply_trp(model, args.trp_depths, args.trp_p, args.trp_lambdas, label_smoothing=args.label_smoothing)
    model.to(device)

    if args.distributed and args.sync_bn:
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    criterion = nn.CrossEntropyLoss(label_smoothing=args.label_smoothing)

    custom_keys_weight_decay = []
    if args.bias_weight_decay is not None:
        custom_keys_weight_decay.append(("bias", args.bias_weight_decay))
    if args.transformer_embedding_decay is not None:
        for key in ["class_token", "position_embedding", "relative_position_bias_table"]:
            custom_keys_weight_decay.append((key, args.transformer_embedding_decay))
    parameters = utils.set_weight_decay(
        model,
        args.weight_decay,
        norm_weight_decay=args.norm_weight_decay,
        custom_keys_weight_decay=custom_keys_weight_decay if len(custom_keys_weight_decay) > 0 else None,
    )

    opt_name = args.opt.lower()
    if opt_name.startswith("sgd"):
        optimizer = torch.optim.SGD(
            parameters,
            lr=args.lr,
            momentum=args.momentum,
            weight_decay=args.weight_decay,
            nesterov="nesterov" in opt_name,
        )
    elif opt_name == "rmsprop":
        optimizer = torch.optim.RMSprop(
            parameters, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay, eps=0.0316, alpha=0.9
        )
    elif opt_name == "adamw":
        optimizer = torch.optim.AdamW(parameters, lr=args.lr, weight_decay=args.weight_decay)
    else:
        raise RuntimeError(f"Invalid optimizer {args.opt}. Only SGD, RMSprop and AdamW are supported.")

    scaler = torch.amp.GradScaler("cuda") if args.amp else None

    args.lr_scheduler = args.lr_scheduler.lower()
    if args.lr_scheduler == "steplr":
        main_lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma)
    elif args.lr_scheduler == "cosineannealinglr":
        main_lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=args.epochs - args.lr_warmup_epochs, eta_min=args.lr_min
        )
    elif args.lr_scheduler == "exponentiallr":
        main_lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=args.lr_gamma)
    else:
        raise RuntimeError(
            f"Invalid lr scheduler '{args.lr_scheduler}'. Only StepLR, CosineAnnealingLR and ExponentialLR "
            "are supported."
        )

    if args.lr_warmup_epochs > 0:
        if args.lr_warmup_method == "linear":
            warmup_lr_scheduler = torch.optim.lr_scheduler.LinearLR(
                optimizer, start_factor=args.lr_warmup_decay, total_iters=args.lr_warmup_epochs
            )
        elif args.lr_warmup_method == "constant":
            warmup_lr_scheduler = torch.optim.lr_scheduler.ConstantLR(
                optimizer, factor=args.lr_warmup_decay, total_iters=args.lr_warmup_epochs
            )
        else:
            raise RuntimeError(
                f"Invalid warmup lr method '{args.lr_warmup_method}'. Only linear and constant are supported."
            )
        lr_scheduler = torch.optim.lr_scheduler.SequentialLR(
            optimizer, schedulers=[warmup_lr_scheduler, main_lr_scheduler], milestones=[args.lr_warmup_epochs]
        )
    else:
        lr_scheduler = main_lr_scheduler

    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        model_without_ddp = model.module

    model_ema = None
    if args.model_ema:
        # Decay adjustment that aims to keep the decay independent from other hyper-parameters originally proposed at:
        # https://github.com/facebookresearch/pycls/blob/f8cd9627/pycls/core/net.py#L123
        #
        # total_ema_updates = (Dataset_size / n_GPUs) * epochs / (batch_size_per_gpu * EMA_steps)
        # We consider constant = Dataset_size for a given dataset/setup and omit it. Thus:
        # adjust = 1 / total_ema_updates ~= n_GPUs * batch_size_per_gpu * EMA_steps / epochs
        adjust = args.world_size * args.batch_size * args.model_ema_steps / args.epochs
        alpha = 1.0 - args.model_ema_decay
        alpha = min(1.0, alpha * adjust)
        model_ema = utils.ExponentialMovingAverage(model_without_ddp, device=device, decay=1.0 - alpha)

    if args.resume:
        checkpoint = torch.load(args.resume, map_location="cpu", weights_only=False)
        model_without_ddp.load_state_dict(checkpoint["model"])
        if not args.test_only:
            optimizer.load_state_dict(checkpoint["optimizer"])
            lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
        args.start_epoch = checkpoint["epoch"] + 1
        if model_ema:
            model_ema.load_state_dict(checkpoint["model_ema"])
        if scaler:
            scaler.load_state_dict(checkpoint["scaler"])

    if args.test_only:
        # We disable the cudnn benchmarking because it can noticeably affect the accuracy
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True
        if model_ema:
            evaluate(model_ema, criterion, data_loader_test, device=device, log_suffix="EMA")
        else:
            evaluate(model, criterion, data_loader_test, device=device)
        return

    print("Start training")
    start_time = time.time()
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args, model_ema, scaler)
        lr_scheduler.step()
        evaluate(model, criterion, data_loader_test, device=device)
        if model_ema:
            evaluate(model_ema, criterion, data_loader_test, device=device, log_suffix="EMA")
        if args.output_dir:
            checkpoint = {
                "model": model_without_ddp.state_dict() if not args.apply_trp else {k: v for k, v in model_without_ddp.state_dict().items() if not k.startswith("trp_blocks")},  # NOTE: remove TRP heads
                "optimizer": optimizer.state_dict(),
                "lr_scheduler": lr_scheduler.state_dict(),
                "epoch": epoch,
                "args": args,
            }
            if model_ema:
                checkpoint["model_ema"] = model_ema.state_dict() if not args.apply_trp else {k: v for k, v in model_ema.state_dict().items() if not k.startswith("trp_blocks")}  # NOTE: remove TRP heads
            if scaler:
                checkpoint["scaler"] = scaler.state_dict()
            utils.save_on_master(checkpoint, os.path.join(args.output_dir, f"model_{epoch}.pth"))
            utils.save_on_master(checkpoint, os.path.join(args.output_dir, "checkpoint.pth"))

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print(f"Training time {total_time_str}")


def get_args_parser(add_help=True):
    import argparse

    parser = argparse.ArgumentParser(description="PyTorch Classification Training", add_help=add_help)

    parser.add_argument("--data-path", default="/datasets01/imagenet_full_size/061417/", type=str, help="dataset path")
    parser.add_argument("--model", default="resnet18", type=str, help="model name")
    parser.add_argument("--device", default="cuda", type=str, help="device (Use cuda or cpu Default: cuda)")
    parser.add_argument(
        "-b", "--batch-size", default=32, type=int, help="images per gpu, the total batch size is $NGPU x batch_size"
    )
    parser.add_argument("--epochs", default=90, type=int, metavar="N", help="number of total epochs to run")
    parser.add_argument(
        "-j", "--workers", default=16, type=int, metavar="N", help="number of data loading workers (default: 16)"
    )
    parser.add_argument("--opt", default="sgd", type=str, help="optimizer")
    parser.add_argument("--lr", default=0.1, type=float, help="initial learning rate")
    parser.add_argument("--momentum", default=0.9, type=float, metavar="M", help="momentum")
    parser.add_argument(
        "--wd",
        "--weight-decay",
        default=1e-4,
        type=float,
        metavar="W",
        help="weight decay (default: 1e-4)",
        dest="weight_decay",
    )
    parser.add_argument(
        "--norm-weight-decay",
        default=None,
        type=float,
        help="weight decay for Normalization layers (default: None, same value as --wd)",
    )
    parser.add_argument(
        "--bias-weight-decay",
        default=None,
        type=float,
        help="weight decay for bias parameters of all layers (default: None, same value as --wd)",
    )
    parser.add_argument(
        "--transformer-embedding-decay",
        default=None,
        type=float,
        help="weight decay for embedding parameters for vision transformer models (default: None, same value as --wd)",
    )
    parser.add_argument(
        "--label-smoothing", default=0.0, type=float, help="label smoothing (default: 0.0)", dest="label_smoothing"
    )
    parser.add_argument("--mixup-alpha", default=0.0, type=float, help="mixup alpha (default: 0.0)")
    parser.add_argument("--cutmix-alpha", default=0.0, type=float, help="cutmix alpha (default: 0.0)")
    parser.add_argument("--lr-scheduler", default="steplr", type=str, help="the lr scheduler (default: steplr)")
    parser.add_argument("--lr-warmup-epochs", default=0, type=int, help="the number of epochs to warmup (default: 0)")
    parser.add_argument(
        "--lr-warmup-method", default="constant", type=str, help="the warmup method (default: constant)"
    )
    parser.add_argument("--lr-warmup-decay", default=0.01, type=float, help="the decay for lr")
    parser.add_argument("--lr-step-size", default=30, type=int, help="decrease lr every step-size epochs")
    parser.add_argument("--lr-gamma", default=0.1, type=float, help="decrease lr by a factor of lr-gamma")
    parser.add_argument("--lr-min", default=0.0, type=float, help="minimum lr of lr schedule (default: 0.0)")
    parser.add_argument("--print-freq", default=10, type=int, help="print frequency")
    parser.add_argument("--output-dir", default=".", type=str, help="path to save outputs")
    parser.add_argument("--resume", default="", type=str, help="path of checkpoint")
    parser.add_argument("--start-epoch", default=0, type=int, metavar="N", help="start epoch")
    parser.add_argument(
        "--cache-dataset",
        dest="cache_dataset",
        help="Cache the datasets for quicker initialization. It also serializes the transforms",
        action="store_true",
    )
    parser.add_argument(
        "--sync-bn",
        dest="sync_bn",
        help="Use sync batch norm",
        action="store_true",
    )
    parser.add_argument(
        "--test-only",
        dest="test_only",
        help="Only test the model",
        action="store_true",
    )
    parser.add_argument("--auto-augment", default=None, type=str, help="auto augment policy (default: None)")
    parser.add_argument("--ra-magnitude", default=9, type=int, help="magnitude of auto augment policy")
    parser.add_argument("--augmix-severity", default=3, type=int, help="severity of augmix policy")
    parser.add_argument("--random-erase", default=0.0, type=float, help="random erasing probability (default: 0.0)")

    # Mixed precision training parameters
    parser.add_argument("--amp", action="store_true", help="Use torch.cuda.amp for mixed precision training")

    # distributed training parameters
    parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes")
    parser.add_argument("--dist-url", default="env://", type=str, help="url used to set up distributed training")
    parser.add_argument(
        "--model-ema", action="store_true", help="enable tracking Exponential Moving Average of model parameters"
    )
    parser.add_argument(
        "--model-ema-steps",
        type=int,
        default=32,
        help="the number of iterations that controls how often to update the EMA model (default: 32)",
    )
    parser.add_argument(
        "--model-ema-decay",
        type=float,
        default=0.99998,
        help="decay factor for Exponential Moving Average of model parameters (default: 0.99998)",
    )
    parser.add_argument(
        "--use-deterministic-algorithms", action="store_true", help="Forces the use of deterministic algorithms only."
    )
    parser.add_argument(
        "--interpolation", default="bilinear", type=str, help="the interpolation method (default: bilinear)"
    )
    parser.add_argument(
        "--val-resize-size", default=256, type=int, help="the resize size used for validation (default: 256)"
    )
    parser.add_argument(
        "--val-crop-size", default=224, type=int, help="the central crop size used for validation (default: 224)"
    )
    parser.add_argument(
        "--train-crop-size", default=224, type=int, help="the random crop size used for training (default: 224)"
    )
    parser.add_argument("--clip-grad-norm", default=None, type=float, help="the maximum gradient norm (default None)")
    parser.add_argument("--ra-sampler", action="store_true", help="whether to use Repeated Augmentation in training")
    parser.add_argument(
        "--ra-reps", default=3, type=int, help="number of repetitions for Repeated Augmentation (default: 3)"
    )
    parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")

    parser.add_argument("--apply-trp", action="store_true", help="enable applying trp")
    parser.add_argument("--trp-depths", type=int, help="trp depth")
    parser.add_argument("--trp-p", type=float, help="trp p")
    parser.add_argument("--trp-lambdas", nargs="+", type=float, help="trp lambdas")

    return parser


if __name__ == "__main__":
    args = get_args_parser().parse_args()
    main(args)
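As a quick orientation for the training entry point above, the following is a minimal sketch of driving it from Python rather than the command line. The dataset path, model choice, and all hyper-parameter values (including the TRP flags) are illustrative assumptions, not the settings used to produce the released checkpoints.

# Hypothetical example: invoke the classification trainer programmatically.
from train import get_args_parser, main

args = get_args_parser().parse_args(
    [
        "--data-path", "/path/to/imagenet",  # assumed ImageFolder layout with train/ and val/
        "--model", "resnet50",
        "--batch-size", "32",
        "--epochs", "1",
        "--output-dir", "./resnet50",
        "--apply-trp",                       # enable the auxiliary TRP heads defined in trplib.py
        "--trp-depths", "1",
        "--trp-p", "0.1",
        "--trp-lambdas", "0.4", "0.2", "0.1",
    ]
)
main(args)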
hpo-examples/image-classification/train_quantization.py
ADDED
@@ -0,0 +1,265 @@
import copy
import datetime
import os
import time

import torch
import torch.ao.quantization
import torch.utils.data
import torchvision
import utils
from torch import nn
from train import evaluate, load_data, train_one_epoch


def main(args):
    if args.output_dir:
        utils.mkdir(args.output_dir)

    utils.init_distributed_mode(args)
    print(args)

    if args.post_training_quantize and args.distributed:
        raise RuntimeError("Post training quantization example should not be performed on distributed mode")

    # Set backend engine to ensure that quantized model runs on the correct kernels
    if args.backend not in torch.backends.quantized.supported_engines:
        raise RuntimeError("Quantized backend not supported: " + str(args.backend))
    torch.backends.quantized.engine = args.backend

    device = torch.device(args.device)
    torch.backends.cudnn.benchmark = True

    # Data loading code
    print("Loading data")
    train_dir = os.path.join(args.data_path, "train")
    val_dir = os.path.join(args.data_path, "val")

    dataset, dataset_test, train_sampler, test_sampler = load_data(train_dir, val_dir, args)
    data_loader = torch.utils.data.DataLoader(
        dataset, batch_size=args.batch_size, sampler=train_sampler, num_workers=args.workers, pin_memory=True
    )

    data_loader_test = torch.utils.data.DataLoader(
        dataset_test, batch_size=args.eval_batch_size, sampler=test_sampler, num_workers=args.workers, pin_memory=True
    )

    print("Creating model", args.model)
    # when training quantized models, we always start from a pre-trained fp32 reference model
    prefix = "quantized_"
    model_name = args.model
    if not model_name.startswith(prefix):
        model_name = prefix + model_name
    model = torchvision.models.get_model(model_name, weights=args.weights, quantize=args.test_only)
    model.to(device)

    if not (args.test_only or args.post_training_quantize):
        model.fuse_model(is_qat=True)
        model.qconfig = torch.ao.quantization.get_default_qat_qconfig(args.backend)
        torch.ao.quantization.prepare_qat(model, inplace=True)

        if args.distributed and args.sync_bn:
            model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    optimizer = torch.optim.SGD(
        model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay
    )

    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma)

    criterion = nn.CrossEntropyLoss()
    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        model_without_ddp = model.module

    if args.resume:
        checkpoint = torch.load(args.resume, map_location="cpu")
        model_without_ddp.load_state_dict(checkpoint["model"])
        optimizer.load_state_dict(checkpoint["optimizer"])
        lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
        args.start_epoch = checkpoint["epoch"] + 1

    if args.post_training_quantize:
        # perform calibration on a subset of the training dataset
        # for that, create a subset of the training dataset
        ds = torch.utils.data.Subset(dataset, indices=list(range(args.batch_size * args.num_calibration_batches)))
        data_loader_calibration = torch.utils.data.DataLoader(
            ds, batch_size=args.batch_size, shuffle=False, num_workers=args.workers, pin_memory=True
        )
        model.eval()
        model.fuse_model(is_qat=False)
        model.qconfig = torch.ao.quantization.get_default_qconfig(args.backend)
        torch.ao.quantization.prepare(model, inplace=True)
        # Calibrate first
        print("Calibrating")
        evaluate(model, criterion, data_loader_calibration, device=device, print_freq=1)
        torch.ao.quantization.convert(model, inplace=True)
        if args.output_dir:
            print("Saving quantized model")
            if utils.is_main_process():
                torch.save(model.state_dict(), os.path.join(args.output_dir, "quantized_post_train_model.pth"))
        print("Evaluating post-training quantized model")
        evaluate(model, criterion, data_loader_test, device=device)
        return

    if args.test_only:
        evaluate(model, criterion, data_loader_test, device=device)
        return

    model.apply(torch.ao.quantization.enable_observer)
    model.apply(torch.ao.quantization.enable_fake_quant)
    start_time = time.time()
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        print("Starting training for epoch", epoch)
        train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args)
        lr_scheduler.step()
        with torch.inference_mode():
            if epoch >= args.num_observer_update_epochs:
                print("Disabling observer for subseq epochs, epoch = ", epoch)
                model.apply(torch.ao.quantization.disable_observer)
            if epoch >= args.num_batch_norm_update_epochs:
                print("Freezing BN for subseq epochs, epoch = ", epoch)
                model.apply(torch.nn.intrinsic.qat.freeze_bn_stats)
            print("Evaluate QAT model")

            evaluate(model, criterion, data_loader_test, device=device, log_suffix="QAT")
            quantized_eval_model = copy.deepcopy(model_without_ddp)
            quantized_eval_model.eval()
            quantized_eval_model.to(torch.device("cpu"))
            torch.ao.quantization.convert(quantized_eval_model, inplace=True)

            print("Evaluate Quantized model")
            evaluate(quantized_eval_model, criterion, data_loader_test, device=torch.device("cpu"))

        model.train()

        if args.output_dir:
            checkpoint = {
                "model": model_without_ddp.state_dict(),
                "eval_model": quantized_eval_model.state_dict(),
                "optimizer": optimizer.state_dict(),
                "lr_scheduler": lr_scheduler.state_dict(),
                "epoch": epoch,
                "args": args,
            }
            utils.save_on_master(checkpoint, os.path.join(args.output_dir, f"model_{epoch}.pth"))
            utils.save_on_master(checkpoint, os.path.join(args.output_dir, "checkpoint.pth"))
        print("Saving models after epoch ", epoch)

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print(f"Training time {total_time_str}")


def get_args_parser(add_help=True):
    import argparse

    parser = argparse.ArgumentParser(description="PyTorch Quantized Classification Training", add_help=add_help)

    parser.add_argument("--data-path", default="/datasets01/imagenet_full_size/061417/", type=str, help="dataset path")
    parser.add_argument("--model", default="mobilenet_v2", type=str, help="model name")
    parser.add_argument("--backend", default="qnnpack", type=str, help="fbgemm or qnnpack")
    parser.add_argument("--device", default="cuda", type=str, help="device (Use cuda or cpu Default: cuda)")

    parser.add_argument(
        "-b", "--batch-size", default=32, type=int, help="images per gpu, the total batch size is $NGPU x batch_size"
    )
    parser.add_argument("--eval-batch-size", default=128, type=int, help="batch size for evaluation")
    parser.add_argument("--epochs", default=90, type=int, metavar="N", help="number of total epochs to run")
    parser.add_argument(
        "--num-observer-update-epochs",
        default=4,
        type=int,
        metavar="N",
        help="number of total epochs to update observers",
    )
    parser.add_argument(
        "--num-batch-norm-update-epochs",
        default=3,
        type=int,
        metavar="N",
        help="number of total epochs to update batch norm stats",
    )
    parser.add_argument(
        "--num-calibration-batches",
        default=32,
        type=int,
        metavar="N",
        help="number of batches of training set for \
            observer calibration ",
    )

    parser.add_argument(
        "-j", "--workers", default=16, type=int, metavar="N", help="number of data loading workers (default: 16)"
    )
    parser.add_argument("--lr", default=0.0001, type=float, help="initial learning rate")
    parser.add_argument("--momentum", default=0.9, type=float, metavar="M", help="momentum")
    parser.add_argument(
        "--wd",
        "--weight-decay",
        default=1e-4,
        type=float,
        metavar="W",
        help="weight decay (default: 1e-4)",
        dest="weight_decay",
    )
    parser.add_argument("--lr-step-size", default=30, type=int, help="decrease lr every step-size epochs")
    parser.add_argument("--lr-gamma", default=0.1, type=float, help="decrease lr by a factor of lr-gamma")
    parser.add_argument("--print-freq", default=10, type=int, help="print frequency")
    parser.add_argument("--output-dir", default=".", type=str, help="path to save outputs")
    parser.add_argument("--resume", default="", type=str, help="path of checkpoint")
    parser.add_argument("--start-epoch", default=0, type=int, metavar="N", help="start epoch")
    parser.add_argument(
        "--cache-dataset",
        dest="cache_dataset",
        help="Cache the datasets for quicker initialization. \
            It also serializes the transforms",
        action="store_true",
    )
    parser.add_argument(
        "--sync-bn",
        dest="sync_bn",
        help="Use sync batch norm",
        action="store_true",
    )
    parser.add_argument(
        "--test-only",
        dest="test_only",
        help="Only test the model",
        action="store_true",
    )
    parser.add_argument(
        "--post-training-quantize",
        dest="post_training_quantize",
        help="Post training quantize the model",
        action="store_true",
    )

    # distributed training parameters
    parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes")
    parser.add_argument("--dist-url", default="env://", type=str, help="url used to set up distributed training")

    parser.add_argument(
        "--interpolation", default="bilinear", type=str, help="the interpolation method (default: bilinear)"
    )
    parser.add_argument(
        "--val-resize-size", default=256, type=int, help="the resize size used for validation (default: 256)"
    )
    parser.add_argument(
        "--val-crop-size", default=224, type=int, help="the central crop size used for validation (default: 224)"
    )
    parser.add_argument(
        "--train-crop-size", default=224, type=int, help="the random crop size used for training (default: 224)"
    )
    parser.add_argument("--clip-grad-norm", default=None, type=float, help="the maximum gradient norm (default None)")
    parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")

    return parser


if __name__ == "__main__":
    args = get_args_parser().parse_args()
    main(args)
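The post-training-quantize branch above follows the eager-mode prepare → calibrate → convert flow of torch.ao.quantization. The sketch below isolates that flow on a quantizable torchvision model so it can be read in one place; the model choice, uninitialized weights, and the random-tensor "calibration" loop are placeholders standing in for the real dataset and checkpoint, not what this repository ships.

# Minimal sketch of eager-mode post-training quantization (assumed x86/fbgemm backend).
import torch
import torch.ao.quantization
import torchvision

model = torchvision.models.quantization.mobilenet_v2(weights=None, quantize=False)
model.eval()
model.fuse_model(is_qat=False)                                    # fuse conv+bn+relu blocks
model.qconfig = torch.ao.quantization.get_default_qconfig("fbgemm")
torch.ao.quantization.prepare(model, inplace=True)                # insert observers

# Calibration: run representative batches through the observed model.
# Random tensors here are only a stand-in for a real calibration DataLoader.
with torch.inference_mode():
    for _ in range(8):
        model(torch.randn(4, 3, 224, 224))

torch.ao.quantization.convert(model, inplace=True)                # swap in int8 modules
print(model(torch.randn(1, 3, 224, 224)).shape)                   # torch.Size([1, 1000])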
hpo-examples/image-classification/transforms.py
ADDED
@@ -0,0 +1,183 @@
import math
from typing import Tuple

import torch
from torch import Tensor
from torchvision.transforms import functional as F


class RandomMixup(torch.nn.Module):
    """Randomly apply Mixup to the provided batch and targets.
    The class implements the data augmentations as described in the paper
    `"mixup: Beyond Empirical Risk Minimization" <https://arxiv.org/abs/1710.09412>`_.

    Args:
        num_classes (int): number of classes used for one-hot encoding.
        p (float): probability of the batch being transformed. Default value is 0.5.
        alpha (float): hyperparameter of the Beta distribution used for mixup.
            Default value is 1.0.
        inplace (bool): boolean to make this transform inplace. Default set to False.
    """

    def __init__(self, num_classes: int, p: float = 0.5, alpha: float = 1.0, inplace: bool = False) -> None:
        super().__init__()

        if num_classes < 1:
            raise ValueError(
                f"Please provide a valid positive value for the num_classes. Got num_classes={num_classes}"
            )

        if alpha <= 0:
            raise ValueError("Alpha param can't be zero.")

        self.num_classes = num_classes
        self.p = p
        self.alpha = alpha
        self.inplace = inplace

    def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
        """
        Args:
            batch (Tensor): Float tensor of size (B, C, H, W)
            target (Tensor): Integer tensor of size (B, )

        Returns:
            Tensor: Randomly transformed batch.
        """
        if batch.ndim != 4:
            raise ValueError(f"Batch ndim should be 4. Got {batch.ndim}")
        if target.ndim != 1:
            raise ValueError(f"Target ndim should be 1. Got {target.ndim}")
        if not batch.is_floating_point():
            raise TypeError(f"Batch dtype should be a float tensor. Got {batch.dtype}.")
        if target.dtype != torch.int64:
            raise TypeError(f"Target dtype should be torch.int64. Got {target.dtype}")

        if not self.inplace:
            batch = batch.clone()
            target = target.clone()

        if target.ndim == 1:
            target = torch.nn.functional.one_hot(target, num_classes=self.num_classes).to(dtype=batch.dtype)

        if torch.rand(1).item() >= self.p:
            return batch, target

        # It's faster to roll the batch by one instead of shuffling it to create image pairs
        batch_rolled = batch.roll(1, 0)
        target_rolled = target.roll(1, 0)

        # Implemented as on mixup paper, page 3.
        lambda_param = float(torch._sample_dirichlet(torch.tensor([self.alpha, self.alpha]))[0])
        batch_rolled.mul_(1.0 - lambda_param)
        batch.mul_(lambda_param).add_(batch_rolled)

        target_rolled.mul_(1.0 - lambda_param)
        target.mul_(lambda_param).add_(target_rolled)

        return batch, target

    def __repr__(self) -> str:
        s = (
            f"{self.__class__.__name__}("
            f"num_classes={self.num_classes}"
            f", p={self.p}"
            f", alpha={self.alpha}"
            f", inplace={self.inplace}"
            f")"
        )
        return s


class RandomCutmix(torch.nn.Module):
    """Randomly apply Cutmix to the provided batch and targets.
    The class implements the data augmentations as described in the paper
    `"CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features"
    <https://arxiv.org/abs/1905.04899>`_.

    Args:
        num_classes (int): number of classes used for one-hot encoding.
        p (float): probability of the batch being transformed. Default value is 0.5.
        alpha (float): hyperparameter of the Beta distribution used for cutmix.
            Default value is 1.0.
        inplace (bool): boolean to make this transform inplace. Default set to False.
    """

    def __init__(self, num_classes: int, p: float = 0.5, alpha: float = 1.0, inplace: bool = False) -> None:
        super().__init__()
        if num_classes < 1:
            raise ValueError("Please provide a valid positive value for the num_classes.")
        if alpha <= 0:
            raise ValueError("Alpha param can't be zero.")

        self.num_classes = num_classes
        self.p = p
        self.alpha = alpha
        self.inplace = inplace

    def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
        """
        Args:
            batch (Tensor): Float tensor of size (B, C, H, W)
            target (Tensor): Integer tensor of size (B, )

        Returns:
            Tensor: Randomly transformed batch.
        """
        if batch.ndim != 4:
            raise ValueError(f"Batch ndim should be 4. Got {batch.ndim}")
        if target.ndim != 1:
            raise ValueError(f"Target ndim should be 1. Got {target.ndim}")
        if not batch.is_floating_point():
            raise TypeError(f"Batch dtype should be a float tensor. Got {batch.dtype}.")
        if target.dtype != torch.int64:
            raise TypeError(f"Target dtype should be torch.int64. Got {target.dtype}")

        if not self.inplace:
            batch = batch.clone()
            target = target.clone()

        if target.ndim == 1:
            target = torch.nn.functional.one_hot(target, num_classes=self.num_classes).to(dtype=batch.dtype)

        if torch.rand(1).item() >= self.p:
            return batch, target

        # It's faster to roll the batch by one instead of shuffling it to create image pairs
        batch_rolled = batch.roll(1, 0)
        target_rolled = target.roll(1, 0)

        # Implemented as on cutmix paper, page 12 (with minor corrections on typos).
        lambda_param = float(torch._sample_dirichlet(torch.tensor([self.alpha, self.alpha]))[0])
        _, H, W = F.get_dimensions(batch)

        r_x = torch.randint(W, (1,))
        r_y = torch.randint(H, (1,))

        r = 0.5 * math.sqrt(1.0 - lambda_param)
        r_w_half = int(r * W)
        r_h_half = int(r * H)

        x1 = int(torch.clamp(r_x - r_w_half, min=0))
        y1 = int(torch.clamp(r_y - r_h_half, min=0))
        x2 = int(torch.clamp(r_x + r_w_half, max=W))
        y2 = int(torch.clamp(r_y + r_h_half, max=H))

        batch[:, :, y1:y2, x1:x2] = batch_rolled[:, :, y1:y2, x1:x2]
        lambda_param = float(1.0 - (x2 - x1) * (y2 - y1) / (W * H))

        target_rolled.mul_(1.0 - lambda_param)
        target.mul_(lambda_param).add_(target_rolled)

        return batch, target

    def __repr__(self) -> str:
        s = (
            f"{self.__class__.__name__}("
            f"num_classes={self.num_classes}"
            f", p={self.p}"
            f", alpha={self.alpha}"
            f", inplace={self.inplace}"
            f")"
        )
        return s
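These two transforms operate on whole collated batches, which is why train.py wires them into the DataLoader's collate_fn rather than into the per-sample preset pipeline. The sketch below shows that wiring in isolation; the toy TensorDataset and the alpha values are illustrative assumptions standing in for the real ImageFolder dataset and command-line flags.

# Hypothetical example: batch-level Mixup/Cutmix via a custom collate_fn.
import torch
import torchvision
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.dataloader import default_collate

from transforms import RandomCutmix, RandomMixup

num_classes = 10
mixupcutmix = torchvision.transforms.RandomChoice(
    [RandomMixup(num_classes, p=1.0, alpha=0.2), RandomCutmix(num_classes, p=1.0, alpha=1.0)]
)

def collate_fn(batch):
    # default_collate stacks the (image, label) pairs; the chosen transform then mixes the batch
    # and returns one-hot (soft) targets of shape [B, num_classes].
    return mixupcutmix(*default_collate(batch))

dataset = TensorDataset(torch.randn(64, 3, 32, 32), torch.randint(0, num_classes, (64,)))
loader = DataLoader(dataset, batch_size=8, collate_fn=collate_fn)
images, soft_targets = next(iter(loader))
print(images.shape, soft_targets.shape)  # torch.Size([8, 3, 32, 32]) torch.Size([8, 10])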
hpo-examples/image-classification/trplib.py
ADDED
@@ -0,0 +1,1181 @@
1 |
+
import torch
|
2 |
+
from torch import nn, Tensor
|
3 |
+
from torch.nn import functional as F
|
4 |
+
|
5 |
+
from torchvision.models.mobilenetv2 import MobileNetV2
|
6 |
+
from torchvision.models.resnet import ResNet
|
7 |
+
from torchvision.models.efficientnet import EfficientNet
|
8 |
+
from torchvision.models.vision_transformer import VisionTransformer
|
9 |
+
from torchvision.models.segmentation.fcn import FCN
|
10 |
+
from torchvision.models.segmentation.deeplabv3 import DeepLabV3
|
11 |
+
|
12 |
+
import transformers
|
13 |
+
from transformers.modeling_outputs import SequenceClassifierOutput, QuestionAnsweringModelOutput, CausalLMOutput, Seq2SeqLMOutput
|
14 |
+
|
15 |
+
from typing import Optional, Tuple, List, Union, Callable
|
16 |
+
from collections import OrderedDict
|
17 |
+
import types
|
18 |
+
|
19 |
+
|
20 |
+
def trp_criterion(trp_blocks: nn.ModuleList, shared_head: Callable, criterion: Callable, lambdas: List[float], hidden_states: Tensor, logits: Tensor, targets: Tensor, loss_normalization=False):
|
21 |
+
loss, mask = criterion(logits, targets)
|
22 |
+
if loss_normalization:
|
23 |
+
coeff = loss.detach()
|
24 |
+
|
25 |
+
embeds = [hidden_states]
|
26 |
+
predictions = []
|
27 |
+
for k, c in enumerate(lambdas):
|
28 |
+
embeds.append(trp_blocks[k](embeds[-1]))
|
29 |
+
predictions.append(shared_head(embeds[-1]))
|
30 |
+
replica_loss, mask = criterion(predictions[-1], targets, mask)
|
31 |
+
loss += c * replica_loss
|
32 |
+
|
33 |
+
if loss_normalization:
|
34 |
+
with torch.no_grad():
|
35 |
+
coeff = torch.exp(coeff) / torch.exp(loss.detach())
|
36 |
+
loss = coeff * loss
|
37 |
+
|
38 |
+
return loss
|
39 |
+
|
40 |
+
|
41 |
+
class TPBlock(nn.Module):
|
42 |
+
def __init__(self, depths: int, in_features: int, p: float, dim=-1):
|
43 |
+
super(TPBlock, self).__init__()
|
44 |
+
|
45 |
+
self.dropout = nn.Dropout(p)
|
46 |
+
|
47 |
+
self.cdim = dim
|
48 |
+
|
49 |
+
blocks = []
|
50 |
+
for _ in range(depths):
|
51 |
+
blocks.append(nn.Linear(in_features, in_features))
|
52 |
+
nn.init.constant_(blocks[-1].weight, 0.0)
|
53 |
+
nn.init.constant_(blocks[-1].bias, 0.0)
|
54 |
+
blocks.append(nn.ReLU())
|
55 |
+
self.blocks = nn.Sequential(*blocks)
|
56 |
+
|
57 |
+
def forward(self, x):
|
58 |
+
x = self.dropout(x)
|
59 |
+
if self.cdim == -1:
|
60 |
+
x = x + self.blocks(x)
|
61 |
+
else:
|
62 |
+
x = x + torch.movedim(self.blocks(torch.movedim(x, self.cdim, -1)), -1, self.cdim)
|
63 |
+
return x
|
64 |
+
|
65 |
+
|
66 |
+
class Config:
|
67 |
+
@staticmethod
|
68 |
+
def gen_criterion(*args, **kwargs):
|
69 |
+
def func(input, target, mask=None):
|
70 |
+
"""
|
71 |
+
Args:
|
72 |
+
input (Tensor): Input tensor.
|
73 |
+
target (Tensor): Target labels.
|
74 |
+
|
75 |
+
Returns:
|
76 |
+
loss (Tensor): Scalar tensor representing the loss.
|
77 |
+
mask (Tensor): Boolean mask tensor with the same shape of target.
|
78 |
+
"""
|
79 |
+
pass
|
80 |
+
return func
|
81 |
+
|
82 |
+
@staticmethod
|
83 |
+
def gen_shared_head(*args, **kwargs):
|
84 |
+
def func(hidden_states):
|
85 |
+
"""
|
86 |
+
Args:
|
87 |
+
hidden_states (Tensor): Hidden States tensor.
|
88 |
+
|
89 |
+
Returns:
|
90 |
+
logits (Tensor): Logits tensor.
|
91 |
+
"""
|
92 |
+
pass
|
93 |
+
return func
|
94 |
+
|
95 |
+
@staticmethod
|
96 |
+
def forward(*args, **kwargs):
|
97 |
+
pass
|
98 |
+
|
99 |
+
|
100 |
+
# Wav2Vec2 for Audio Classification
|
101 |
+
class Wav2Vec2ForSequenceClassificationConfig(Config):
|
102 |
+
_HIDDEN_STATES_START_POSITION = 2
|
103 |
+
|
104 |
+
@staticmethod
|
105 |
+
def gen_criterion():
|
106 |
+
def func(input, target, mask=None):
|
107 |
+
"""
|
108 |
+
Args:
|
109 |
+
input (Tensor): Input tensor of shape [B, C].
|
110 |
+
target (Tensor): Target labels of shape [B].
|
111 |
+
|
112 |
+
Returns:
|
113 |
+
loss (Tensor): Scalar tensor representing the loss.
|
114 |
+
mask (Tensor): Boolean mask tensor of shape [B].
|
115 |
+
"""
|
116 |
+
if mask is None:
|
117 |
+
mask = torch.ones_like(target, dtype=torch.float32, device=target.device)
|
118 |
+
|
119 |
+
unmasked_loss = F.cross_entropy(input, target, reduction="none")
|
120 |
+
loss = torch.sum(mask * unmasked_loss) / (torch.sum(mask) + 1e-6)
|
121 |
+
|
122 |
+
with torch.no_grad():
|
123 |
+
mask = mask * torch.eq(torch.argmax(input, dim=1), target).to(input.dtype)
|
124 |
+
|
125 |
+
return loss, mask
|
126 |
+
return func
|
127 |
+
|
128 |
+
@staticmethod
|
129 |
+
def gen_shared_head(self, attention_mask):
|
130 |
+
def func(hidden_states):
|
131 |
+
"""
|
132 |
+
Args:
|
133 |
+
hidden_states (Tensor): Hidden States of shape [B, L, hidden_units].
|
134 |
+
|
135 |
+
Returns:
|
136 |
+
logits (Tensor): Logits tensor of shape [B, C].
|
137 |
+
"""
|
138 |
+
_hidden_states = self.projector(hidden_states)
|
139 |
+
if attention_mask is None:
|
140 |
+
pooled_output = _hidden_states.mean(dim=1)
|
141 |
+
else:
|
142 |
+
padding_mask = self._get_feature_vector_attention_mask(_hidden_states.shape[1], attention_mask)
|
143 |
+
expand_padding_mask = padding_mask.unsqueeze(-1).repeat(1, 1, _hidden_states.shape[2])
|
144 |
+
_hidden_states[~expand_padding_mask] = 0.0
|
145 |
+
pooled_output = _hidden_states.sum(dim=1) / padding_mask.sum(dim=1).view(-1, 1)
|
146 |
+
|
147 |
+
logits = self.classifier(pooled_output)
|
148 |
+
return logits
|
149 |
+
return func
|
150 |
+
|
151 |
+
@staticmethod
|
152 |
+
def gen_forward(lambdas, loss_normalization=False):
|
153 |
+
def func(
|
154 |
+
self,
|
155 |
+
input_values: Optional[torch.Tensor],
|
156 |
+
attention_mask: Optional[torch.Tensor] = None,
|
157 |
+
output_attentions: Optional[bool] = None,
|
158 |
+
output_hidden_states: Optional[bool] = None,
|
159 |
+
return_dict: Optional[bool] = None,
|
160 |
+
labels: Optional[torch.Tensor] = None,
|
161 |
+
) -> Union[Tuple, SequenceClassifierOutput]:
|
162 |
+
r"""
|
163 |
+
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
164 |
+
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
|
165 |
+
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
|
166 |
+
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
167 |
+
"""
|
168 |
+
|
169 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
170 |
+
output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states
|
171 |
+
|
172 |
+
outputs = self.wav2vec2(
|
173 |
+
input_values,
|
174 |
+
attention_mask=attention_mask,
|
175 |
+
output_attentions=output_attentions,
|
176 |
+
output_hidden_states=output_hidden_states,
|
177 |
+
return_dict=return_dict,
|
178 |
+
)
|
179 |
+
|
180 |
+
if self.config.use_weighted_layer_sum:
|
181 |
+
hidden_states = outputs[Wav2Vec2ForSequenceClassificationConfig._HIDDEN_STATES_START_POSITION]
|
182 |
+
hidden_states = torch.stack(hidden_states, dim=1)
|
183 |
+
norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
|
184 |
+
hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
|
185 |
+
else:
|
186 |
+
hidden_states = outputs[0]
|
187 |
+
|
188 |
+
_hidden_states = self.projector(hidden_states)
|
189 |
+
if attention_mask is None:
|
190 |
+
pooled_output = _hidden_states.mean(dim=1)
|
191 |
+
else:
|
192 |
+
padding_mask = self._get_feature_vector_attention_mask(_hidden_states.shape[1], attention_mask)
|
193 |
+
expand_padding_mask = padding_mask.unsqueeze(-1).repeat(1, 1, _hidden_states.shape[2])
|
194 |
+
_hidden_states[~expand_padding_mask] = 0.0
|
195 |
+
pooled_output = _hidden_states.sum(dim=1) / padding_mask.sum(dim=1).view(-1, 1)
|
196 |
+
|
197 |
+
logits = self.classifier(pooled_output)
|
198 |
+
|
199 |
+
loss = None
|
200 |
+
if labels is not None:
|
201 |
+
shared_head = Wav2Vec2ForSequenceClassificationConfig.gen_shared_head(self, attention_mask)
|
202 |
+
criterion = Wav2Vec2ForSequenceClassificationConfig.gen_criterion()
|
203 |
+
loss = trp_criterion(self.trp_blocks, shared_head, criterion, lambdas, hidden_states, logits.view(-1, self.config.num_labels), labels.view(-1), loss_normalization) # NOTE: Apply TRP!
|
204 |
+
|
205 |
+
if not return_dict:
|
206 |
+
output = (logits,) + outputs[Wav2Vec2ForSequenceClassificationConfig._HIDDEN_STATES_START_POSITION:]
|
207 |
+
return ((loss,) + output) if loss is not None else output
|
208 |
+
|
209 |
+
return SequenceClassifierOutput(
|
210 |
+
loss=loss,
|
211 |
+
logits=logits,
|
212 |
+
hidden_states=outputs.hidden_states,
|
213 |
+
attentions=outputs.attentions,
|
214 |
+
)
|
215 |
+
return func
|
216 |
+
|
217 |
+
|
218 |
+
# MobileNetV2 for Image Classification
|
219 |
+
class MobileNetV2Config(Config):
|
220 |
+
@staticmethod
|
221 |
+
def gen_criterion(label_smoothing=0.0, top_k=1):
|
222 |
+
def func(input, target, mask=None):
|
223 |
+
"""
|
224 |
+
Args:
|
225 |
+
input (Tensor): Input tensor of shape [B, C].
|
226 |
+
target (Tensor): Target labels of shape [B] or [B, C].
|
227 |
+
|
228 |
+
Returns:
|
229 |
+
loss (Tensor): Scalar tensor representing the loss.
|
230 |
+
mask (Tensor): Boolean mask tensor of shape [B].
|
231 |
+
"""
|
232 |
+
label = torch.argmax(target, dim=1) if label_smoothing > 0.0 else target
|
233 |
+
|
234 |
+
unmasked_loss = F.cross_entropy(input, label, reduction="none", label_smoothing=label_smoothing)
|
235 |
+
if mask is None:
|
236 |
+
mask = torch.ones_like(unmasked_loss, dtype=torch.float32, device=target.device)
|
237 |
+
loss = torch.sum(mask * unmasked_loss) / (torch.sum(mask) + 1e-6)
|
238 |
+
|
239 |
+
with torch.no_grad():
|
240 |
+
topk_values, topk_indices = torch.topk(input, top_k, dim=-1)
|
241 |
+
mask = mask * torch.eq(topk_indices, label[:, None]).any(dim=-1).to(input.dtype)
|
242 |
+
|
243 |
+
return loss, mask
|
244 |
+
return func
|
245 |
+
|
246 |
+
@staticmethod
|
247 |
+
def gen_shared_head(self):
|
248 |
+
def func(x):
|
249 |
+
"""
|
250 |
+
Args:
|
251 |
+
x (Tensor): Hidden States tensor of shape [B, hidden_units].
|
252 |
+
|
253 |
+
Returns:
|
254 |
+
logits (Tensor): Logits tensor of shape [B, C].
|
255 |
+
"""
|
256 |
+
logits = self.classifier(x)
|
257 |
+
return logits
|
258 |
+
return func
|
259 |
+
|
260 |
+
@staticmethod
|
261 |
+
def gen_forward(lambdas, loss_normalization=True, label_smoothing=0.0, top_k=1):
|
262 |
+
def func(self, images: Tensor, targets=None):
|
263 |
+
x = self.features(images)
|
264 |
+
x = nn.functional.adaptive_avg_pool2d(x, (1, 1))
|
265 |
+
x = torch.flatten(x, 1)
|
266 |
+
logits = self.classifier(x)
|
267 |
+
|
268 |
+
if self.training:
|
269 |
+
torch._assert(targets is not None, "targets should not be none when in training mode")
|
270 |
+
shared_head = MobileNetV2Config.gen_shared_head(self)
|
271 |
+
criterion = MobileNetV2Config.gen_criterion(label_smoothing, top_k)
|
272 |
+
loss = trp_criterion(self.trp_blocks, shared_head, criterion, lambdas, x, logits, targets, loss_normalization)
|
273 |
+
return logits, loss
|
274 |
+
return logits
|
275 |
+
return func
|
276 |
+
|
277 |
+
|
278 |
+
# ResNet for Image Classification
|
279 |
+
class ResNetConfig(MobileNetV2Config):
|
280 |
+
@staticmethod
|
281 |
+
def gen_shared_head(self):
|
282 |
+
def func(x):
|
283 |
+
"""
|
284 |
+
Args:
|
285 |
+
x (Tensor): Hidden States tensor of shape [B, hidden_units].
|
286 |
+
|
287 |
+
Returns:
|
288 |
+
logits (Tensor): Logits tensor of shape [B, C].
|
289 |
+
"""
|
290 |
+
logits = self.fc(x)
|
291 |
+
return logits
|
292 |
+
return func
|
293 |
+
|
294 |
+
@staticmethod
|
295 |
+
def gen_forward(lambdas, loss_normalization=True, label_smoothing=0.0, top_k=1):
|
296 |
+
def func(self, images: Tensor, targets=None):
|
297 |
+
x = self.conv1(images)
|
298 |
+
x = self.bn1(x)
|
299 |
+
x = self.relu(x)
|
300 |
+
x = self.maxpool(x)
|
301 |
+
|
302 |
+
x = self.layer1(x)
|
303 |
+
x = self.layer2(x)
|
304 |
+
x = self.layer3(x)
|
305 |
+
x = self.layer4(x)
|
306 |
+
|
307 |
+
x = self.avgpool(x)
|
308 |
+
x = torch.flatten(x, 1)
|
309 |
+
logits = self.fc(x)
|
310 |
+
|
311 |
+
if self.training:
|
312 |
+
torch._assert(targets is not None, "targets should not be none when in training mode")
|
313 |
+
shared_head = ResNetConfig.gen_shared_head(self)
|
314 |
+
criterion = ResNetConfig.gen_criterion(label_smoothing, top_k)
|
315 |
+
loss = trp_criterion(self.trp_blocks, shared_head, criterion, lambdas, x, logits, targets, loss_normalization)
|
316 |
+
return logits, loss
|
317 |
+
return logits
|
318 |
+
return func
|
319 |
+
|
320 |
+
|
321 |
+
# EfficientNet for Image Classification
|
322 |
+
class EfficientNetConfig(MobileNetV2Config):
|
323 |
+
@staticmethod
|
324 |
+
def gen_shared_head(self):
|
325 |
+
def func(x):
|
326 |
+
"""
|
327 |
+
Args:
|
328 |
+
x (Tensor): Hidden States tensor of shape [B, hidden_units].
|
329 |
+
|
330 |
+
Returns:
|
331 |
+
logits (Tensor): Logits tensor of shape [B, C].
|
332 |
+
"""
|
333 |
+
logits = self.classifier(x)
|
334 |
+
return logits
|
335 |
+
return func
|
336 |
+
|
337 |
+
@staticmethod
|
338 |
+
def gen_forward(lambdas, loss_normalization=True, label_smoothing=0.0, top_k=1):
|
339 |
+
def func(self, images: Tensor, targets=None):
|
340 |
+
x = self.features(images)
|
341 |
+
x = self.avgpool(x)
|
342 |
+
x = torch.flatten(x, 1)
|
343 |
+
logits = self.classifier(x)
|
344 |
+
|
345 |
+
if self.training:
|
346 |
+
torch._assert(targets is not None, "targets should not be none when in training mode")
|
347 |
+
shared_head = EfficientNetConfig.gen_shared_head(self)
|
348 |
+
criterion = EfficientNetConfig.gen_criterion(label_smoothing, top_k)
|
349 |
+
loss = trp_criterion(self.trp_blocks, shared_head, criterion, lambdas, x, logits, targets, loss_normalization)
|
350 |
+
return logits, loss
|
351 |
+
return logits
|
352 |
+
return func
|
353 |
+
|
354 |
+
|
355 |
+
# ViT for Image Classification
|
356 |
+
class VisionTransformerConfig(MobileNetV2Config):
|
357 |
+
@staticmethod
|
358 |
+
def gen_shared_head(self):
|
359 |
+
def func(x):
|
360 |
+
"""
|
361 |
+
Args:
|
362 |
+
x (Tensor): Hidden States tensor of shape [B, hidden_units].
|
363 |
+
|
364 |
+
Returns:
|
365 |
+
logits (Tensor): Logits tensor of shape [B, C].
|
366 |
+
"""
|
367 |
+
logits = self.heads(x)
|
368 |
+
return logits
|
369 |
+
return func
|
370 |
+
|
371 |
+
@staticmethod
|
372 |
+
def gen_forward(lambdas, loss_normalization=True, label_smoothing=0.0, top_k=1):
|
373 |
+
def func(self, images: Tensor, targets=None):
|
374 |
+
x = self._process_input(images)
|
375 |
+
n = x.shape[0]
|
376 |
+
batch_class_token = self.class_token.expand(n, -1, -1)
|
377 |
+
x = torch.cat([batch_class_token, x], dim=1)
|
378 |
+
x = self.encoder(x)
|
379 |
+
x = x[:, 0]
|
380 |
+
|
381 |
+
logits = self.heads(x)
|
382 |
+
|
383 |
+
if self.training:
|
384 |
+
torch._assert(targets is not None, "targets should not be none when in training mode")
|
385 |
+
shared_head = VisionTransformerConfig.gen_shared_head(self)
|
386 |
+
criterion = VisionTransformerConfig.gen_criterion(label_smoothing, top_k)
|
387 |
+
loss = trp_criterion(self.trp_blocks, shared_head, criterion, lambdas, x, logits, targets, loss_normalization)
|
388 |
+
return logits, loss
|
389 |
+
return logits
|
390 |
+
return func
|
391 |
+
|
392 |
+
|
393 |
+
# Bert for Question Answering
|
394 |
+
class BertForQuestionAnsweringConfig(Config):
|
395 |
+
@staticmethod
|
396 |
+
def gen_criterion(top_k=1):
|
397 |
+
def func(input, target: List[Tensor], mask=None):
|
398 |
+
"""
|
399 |
+
Args:
|
400 |
+
input (Tensor): Input tensor of shape [B, C, 2].
|
401 |
+
target (List[Tensor]):
|
402 |
+
Start Positions of shape [B].
|
403 |
+
End Positions of shape [B].
|
404 |
+
|
405 |
+
Returns:
|
406 |
+
loss (Tensor): Scalar tensor representing the loss.
|
407 |
+
mask (Tensor): Boolean mask tensor of shape [B].
|
408 |
+
"""
|
409 |
+
start_positions, end_positions = target
|
410 |
+
|
411 |
+
if mask is None:
|
412 |
+
mask = torch.ones_like(start_positions, dtype=torch.float32, device=start_positions.device)
|
413 |
+
|
414 |
+
start_logits, end_logits = input.split(1, dim=-1)
|
415 |
+
start_logits = start_logits.squeeze(-1).contiguous()
|
416 |
+
end_logits = end_logits.squeeze(-1).contiguous()
|
417 |
+
|
418 |
+
# If we are on multi-GPU, splitting adds a dimension
|
419 |
+
if len(start_positions.size()) > 1:
|
420 |
+
start_positions = start_positions.squeeze(-1)
|
421 |
+
if len(end_positions.size()) > 1:
|
422 |
+
end_positions = end_positions.squeeze(-1)
|
423 |
+
# sometimes the start/end positions are outside our model inputs, we ignore these terms
|
424 |
+
ignored_index = start_logits.size(1)
|
425 |
+
start_positions = start_positions.clamp(0, ignored_index)
|
426 |
+
end_positions = end_positions.clamp(0, ignored_index)
|
427 |
+
|
428 |
+
masked_start_losses = F.cross_entropy(start_logits, start_positions, ignore_index=ignored_index, reduction="none")
|
429 |
+
start_loss = torch.sum(mask * masked_start_losses) / (torch.sum(mask) + 1e-6)
|
430 |
+
masked_end_losses = F.cross_entropy(end_logits, end_positions, ignore_index=ignored_index, reduction="none")
|
431 |
+
end_loss = torch.sum(mask * masked_end_losses) / (torch.sum(mask) + 1e-6)
|
432 |
+
|
433 |
+
with torch.no_grad():
|
434 |
+
topk_values, topk_indices = torch.topk(start_logits, top_k, dim=1)
|
435 |
+
mask = mask * torch.eq(topk_indices, start_positions[:, None]).any(dim=1).to(start_logits.dtype)
|
436 |
+
topk_values, topk_indices = torch.topk(end_logits, top_k, dim=1)
|
437 |
+
mask = mask * torch.eq(topk_indices, end_positions[:, None]).any(dim=1).to(end_logits.dtype)
|
438 |
+
|
439 |
+
return (start_loss + end_loss) / 2, mask
|
440 |
+
return func
|
441 |
+
|
442 |
+
@staticmethod
|
443 |
+
def gen_shared_head(self):
|
444 |
+
def func(hidden_states):
|
445 |
+
"""
|
446 |
+
Args:
|
447 |
+
hidden_states (Tensor): Hidden States of shape [B, C, hidden_units].
|
448 |
+
|
449 |
+
Returns:
|
450 |
+
logits (Tensor): Logits tensor of shape [B, C, 2].
|
451 |
+
"""
|
452 |
+
logits = self.qa_outputs(hidden_states)
|
453 |
+
return logits
|
454 |
+
return func
|
455 |
+
|
456 |
+
@staticmethod
|
457 |
+
def gen_forward(lambdas, loss_normalization=True, top_k=1):
|
458 |
+
def func(
|
459 |
+
self,
|
460 |
+
input_ids: Optional[torch.Tensor] = None,
|
461 |
+
attention_mask: Optional[torch.Tensor] = None,
|
462 |
+
token_type_ids: Optional[torch.Tensor] = None,
|
463 |
+
position_ids: Optional[torch.Tensor] = None,
|
464 |
+
head_mask: Optional[torch.Tensor] = None,
|
465 |
+
inputs_embeds: Optional[torch.Tensor] = None,
|
466 |
+
start_positions: Optional[torch.Tensor] = None,
|
467 |
+
end_positions: Optional[torch.Tensor] = None,
|
468 |
+
output_attentions: Optional[bool] = None,
|
469 |
+
output_hidden_states: Optional[bool] = None,
|
470 |
+
return_dict: Optional[bool] = None,
|
471 |
+
) -> Union[Tuple[torch.Tensor], QuestionAnsweringModelOutput]:
|
472 |
+
r"""
|
473 |
+
start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
474 |
+
Labels for position (index) of the start of the labelled span for computing the token classification loss.
|
475 |
+
Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
|
476 |
+
are not taken into account for computing the loss.
|
477 |
+
end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
478 |
+
Labels for position (index) of the end of the labelled span for computing the token classification loss.
|
479 |
+
Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
|
480 |
+
are not taken into account for computing the loss.
|
481 |
+
"""
|
482 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
483 |
+
|
484 |
+
outputs = self.bert(
|
485 |
+
input_ids,
|
486 |
+
attention_mask=attention_mask,
|
487 |
+
token_type_ids=token_type_ids,
|
488 |
+
position_ids=position_ids,
|
489 |
+
head_mask=head_mask,
|
490 |
+
inputs_embeds=inputs_embeds,
|
491 |
+
output_attentions=output_attentions,
|
492 |
+
output_hidden_states=output_hidden_states,
|
493 |
+
return_dict=return_dict,
|
494 |
+
)
|
495 |
+
|
496 |
+
sequence_output = outputs[0]
|
497 |
+
|
498 |
+
logits = self.qa_outputs(sequence_output)
|
499 |
+
start_logits, end_logits = logits.split(1, dim=-1)
|
500 |
+
start_logits = start_logits.squeeze(-1).contiguous()
|
501 |
+
end_logits = end_logits.squeeze(-1).contiguous()
|
502 |
+
|
503 |
+
total_loss = None
|
504 |
+
if start_positions is not None and end_positions is not None:
|
505 |
+
shared_head = BertForQuestionAnsweringConfig.gen_shared_head(self)
|
506 |
+
criterion = BertForQuestionAnsweringConfig.gen_criterion()
|
507 |
+
total_loss = trp_criterion(self.trp_blocks, shared_head, criterion, lambdas, sequence_output, logits, [start_positions, end_positions], loss_normalization) # NOTE: Apply TRP!
|
508 |
+
|
509 |
+
if not return_dict:
|
510 |
+
output = (start_logits, end_logits) + outputs[2:]
|
511 |
+
return ((total_loss,) + output) if total_loss is not None else output
|
512 |
+
|
513 |
+
return QuestionAnsweringModelOutput(
|
514 |
+
loss=total_loss,
|
515 |
+
start_logits=start_logits,
|
516 |
+
end_logits=end_logits,
|
517 |
+
hidden_states=outputs.hidden_states,
|
518 |
+
attentions=outputs.attentions,
|
519 |
+
)
|
520 |
+
return func
|
521 |
+
|
522 |
+
|
523 |
+
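A quick, self-contained sketch of the span-extraction criterion above (shapes follow the docstring; the sequence length 384 is just an illustrative value):

```python
import torch

criterion = BertForQuestionAnsweringConfig.gen_criterion(top_k=1)

qa_logits = torch.randn(4, 384, 2)        # [B, seq_len, 2] from qa_outputs
start_pos = torch.randint(0, 384, (4,))   # [B]
end_pos = torch.randint(0, 384, (4,))     # [B]

loss, mask = criterion(qa_logits, [start_pos, end_pos])
# `loss` averages the start- and end-position cross-entropies; `mask` ([B])
# keeps only the samples whose start *and* end are both in the top-k guesses.
```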
# FCN for Semantic Segmentation
|
524 |
+
class FCNConfig(Config):
|
525 |
+
@staticmethod
|
526 |
+
def gen_criterion(top_k=1):
|
527 |
+
def func(input, target, mask=None):
|
528 |
+
"""
|
529 |
+
Args:
|
530 |
+
input (Tensor): Input tensor of shape [B, C, H, W].
|
531 |
+
target (Tensor): Target labels of shape [B, H, W].
|
532 |
+
|
533 |
+
Returns:
|
534 |
+
loss (Tensor): Scalar tensor representing the loss.
|
535 |
+
mask (Tensor): Boolean mask tensor of shape [B, H, W].
|
536 |
+
"""
|
537 |
+
if mask is None:
|
538 |
+
mask = torch.ones_like(target, dtype=torch.float32, device=target.device)
|
539 |
+
|
540 |
+
masked_loss = F.cross_entropy(input, target, ignore_index=255, reduction="none")
|
541 |
+
loss = torch.sum(mask * masked_loss) / (torch.sum(mask) + 1e-6)
|
542 |
+
|
543 |
+
with torch.no_grad():
|
544 |
+
topk_values, topk_indices = torch.topk(input, top_k, dim=1)
|
545 |
+
mask = mask * torch.eq(topk_indices, target[:, None, :, :]).any(dim=1).to(input.dtype)
|
546 |
+
# mask = mask * torch.eq(torch.argmax(x, dim=1), target).to(x.dtype)
|
547 |
+
|
548 |
+
return loss, mask
|
549 |
+
return func
|
550 |
+
|
551 |
+
@staticmethod
|
552 |
+
def gen_out_shared_head(self, input_shape):
|
553 |
+
def func(features):
|
554 |
+
"""
|
555 |
+
Args:
|
556 |
+
features (Tensor): features tensor of shape [B, hidden_units, H, W].
|
557 |
+
|
558 |
+
Returns:
|
559 |
+
result (Tensor): Result tensor of shape [B, C, H, W].
|
560 |
+
"""
|
561 |
+
x = self.classifier(features)
|
562 |
+
result = F.interpolate(x, size=input_shape, mode="bilinear", align_corners=False)
|
563 |
+
return result
|
564 |
+
return func
|
565 |
+
|
566 |
+
@staticmethod
|
567 |
+
def gen_aux_shared_head(self, input_shape):
|
568 |
+
def func(features):
|
569 |
+
"""
|
570 |
+
Args:
|
571 |
+
features (Tensor): features tensor of shape [B, hidden_units, H, W].
|
572 |
+
|
573 |
+
Returns:
|
574 |
+
result (Tensor): Result tensor of shape [B, C, H, W].
|
575 |
+
"""
|
576 |
+
x = self.aux_classifier(features)
|
577 |
+
result = F.interpolate(x, size=input_shape, mode="bilinear", align_corners=False)
|
578 |
+
return result
|
579 |
+
return func
|
580 |
+
|
581 |
+
@staticmethod
|
582 |
+
def gen_forward(lambdas, loss_normalization=True, top_k=1):
|
583 |
+
def func(self, images: Tensor, targets=None):
|
584 |
+
input_shape = images.shape[-2:]
|
585 |
+
# contract: features is a dict of tensors
|
586 |
+
features = self.backbone(images)
|
587 |
+
|
588 |
+
result = OrderedDict()
|
589 |
+
x = features["out"]
|
590 |
+
x = self.classifier(x)
|
591 |
+
x = F.interpolate(x, size=input_shape, mode="bilinear", align_corners=False)
|
592 |
+
result["out"] = x
|
593 |
+
|
594 |
+
if self.aux_classifier is not None:
|
595 |
+
x = features["aux"]
|
596 |
+
x = self.aux_classifier(x)
|
597 |
+
x = F.interpolate(x, size=input_shape, mode="bilinear", align_corners=False)
|
598 |
+
result["aux"] = x
|
599 |
+
|
600 |
+
if self.training:
|
601 |
+
torch._assert(targets is not None, "targets should not be none when in training mode")
|
602 |
+
out_shared_head = FCNConfig.gen_out_shared_head(self, input_shape)
|
603 |
+
aux_shared_head = FCNConfig.gen_aux_shared_head(self, input_shape)
|
604 |
+
criterion = FCNConfig.gen_criterion(top_k)
|
605 |
+
out_loss = trp_criterion(self.out_trp_blocks, out_shared_head, criterion, lambdas, features["out"], result["out"], targets, loss_normalization)
|
606 |
+
aux_loss = trp_criterion(self.aux_trp_blocks, aux_shared_head, criterion, lambdas, features["aux"], result["aux"], targets, loss_normalization)
|
607 |
+
loss = out_loss + 0.5 * aux_loss
|
608 |
+
return result, loss
|
609 |
+
return result
|
610 |
+
return func
|
611 |
+
|
612 |
+
|
613 |
+
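The same masking idea applied per pixel; a small sketch of the segmentation criterion defined in `FCNConfig` above (21 classes and a 64x64 crop are illustrative numbers):

```python
import torch

criterion = FCNConfig.gen_criterion(top_k=1)

seg_logits = torch.randn(2, 21, 64, 64)           # [B, C, H, W]
seg_target = torch.randint(0, 21, (2, 64, 64))    # [B, H, W]; 255 marks ignored pixels

loss, mask = criterion(seg_logits, seg_target)
print(mask.shape)   # torch.Size([2, 64, 64]) - per-pixel "top-k correct" mask
```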
# DeepLabV3Config for Semantic Segmentation
|
614 |
+
class DeepLabV3Config(FCNConfig):
|
615 |
+
pass
|
616 |
+
|
617 |
+
|
618 |
+
# Bert for Text Classification
|
619 |
+
class BertForSequenceClassificationConfig(Config):
|
620 |
+
@staticmethod
|
621 |
+
def gen_criterion():
|
622 |
+
def func(input, target, mask=None):
|
623 |
+
"""
|
624 |
+
Args:
|
625 |
+
input (Tensor): Input tensor of shape [B, C].
|
626 |
+
target (Tensor): Target labels of shape [B].
|
627 |
+
|
628 |
+
Returns:
|
629 |
+
loss (Tensor): Scalar tensor representing the loss.
|
630 |
+
mask (Tensor): Boolean mask tensor of shape [B].
|
631 |
+
"""
|
632 |
+
if mask is None:
|
633 |
+
mask = torch.ones_like(target, dtype=torch.float32, device=target.device)
|
634 |
+
|
635 |
+
unmasked_loss = F.cross_entropy(input, target, reduction="none")
|
636 |
+
loss = torch.sum(mask * unmasked_loss) / (torch.sum(mask) + 1e-6)
|
637 |
+
|
638 |
+
with torch.no_grad():
|
639 |
+
mask = mask * torch.eq(torch.argmax(input, dim=1), target).to(input.dtype)
|
640 |
+
|
641 |
+
return loss, mask
|
642 |
+
return func
|
643 |
+
|
644 |
+
@staticmethod
|
645 |
+
def gen_shared_head(self):
|
646 |
+
def func(hidden_states):
|
647 |
+
"""
|
648 |
+
Args:
|
649 |
+
hidden_states (Tensor): Hidden States of shape [B, hidden_units].
|
650 |
+
|
651 |
+
Returns:
|
652 |
+
logits (Tensor): Logits tensor of shape [B, C].
|
653 |
+
"""
|
654 |
+
logits = self.classifier(hidden_states)
|
655 |
+
return logits
|
656 |
+
return func
|
657 |
+
|
658 |
+
@staticmethod
|
659 |
+
def gen_forward(lambdas, loss_normalization=False):
|
660 |
+
def func(
|
661 |
+
self,
|
662 |
+
input_ids: Optional[torch.Tensor] = None,
|
663 |
+
attention_mask: Optional[torch.Tensor] = None,
|
664 |
+
token_type_ids: Optional[torch.Tensor] = None,
|
665 |
+
position_ids: Optional[torch.Tensor] = None,
|
666 |
+
head_mask: Optional[torch.Tensor] = None,
|
667 |
+
inputs_embeds: Optional[torch.Tensor] = None,
|
668 |
+
labels: Optional[torch.Tensor] = None,
|
669 |
+
output_attentions: Optional[bool] = None,
|
670 |
+
output_hidden_states: Optional[bool] = None,
|
671 |
+
return_dict: Optional[bool] = None,
|
672 |
+
) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
|
673 |
+
r"""
|
674 |
+
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
675 |
+
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
|
676 |
+
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
|
677 |
+
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
678 |
+
"""
|
679 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
680 |
+
|
681 |
+
outputs = self.bert(
|
682 |
+
input_ids,
|
683 |
+
attention_mask=attention_mask,
|
684 |
+
token_type_ids=token_type_ids,
|
685 |
+
position_ids=position_ids,
|
686 |
+
head_mask=head_mask,
|
687 |
+
inputs_embeds=inputs_embeds,
|
688 |
+
output_attentions=output_attentions,
|
689 |
+
output_hidden_states=output_hidden_states,
|
690 |
+
return_dict=return_dict,
|
691 |
+
)
|
692 |
+
|
693 |
+
pooled_output = outputs[1]
|
694 |
+
|
695 |
+
pooled_output = self.dropout(pooled_output)
|
696 |
+
logits = self.classifier(pooled_output)
|
697 |
+
|
698 |
+
loss = None
|
699 |
+
if labels is not None:
|
700 |
+
assert self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int) # TODO: remove this
|
701 |
+
if self.config.problem_type is None:
|
702 |
+
if self.num_labels == 1:
|
703 |
+
self.config.problem_type = "regression"
|
704 |
+
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
|
705 |
+
self.config.problem_type = "single_label_classification"
|
706 |
+
else:
|
707 |
+
self.config.problem_type = "multi_label_classification"
|
708 |
+
|
709 |
+
if self.config.problem_type == "regression":
|
710 |
+
if self.num_labels == 1:
|
711 |
+
loss = F.mse_loss(logits.squeeze(), labels.squeeze())
|
712 |
+
else:
|
713 |
+
loss = F.mse_loss(logits, labels)
|
714 |
+
elif self.config.problem_type == "single_label_classification":
|
715 |
+
shared_head = BertForSequenceClassificationConfig.gen_shared_head(self)
|
716 |
+
criterion = BertForSequenceClassificationConfig.gen_criterion()
|
717 |
+
loss = trp_criterion(self.trp_blocks, shared_head, criterion, lambdas, pooled_output, logits, labels, loss_normalization)
|
718 |
+
elif self.config.problem_type == "multi_label_classification":
|
719 |
+
loss = F.binary_cross_entropy_with_logits(logits, labels)
|
720 |
+
if not return_dict:
|
721 |
+
output = (logits,) + outputs[2:]
|
722 |
+
return ((loss,) + output) if loss is not None else output
|
723 |
+
|
724 |
+
return SequenceClassifierOutput(
|
725 |
+
loss=loss,
|
726 |
+
logits=logits,
|
727 |
+
hidden_states=outputs.hidden_states,
|
728 |
+
attentions=outputs.attentions,
|
729 |
+
)
|
730 |
+
return func
|
731 |
+
|
732 |
+
|
733 |
+
# Roberta for Text Classification
|
734 |
+
class RobertaForSequenceClassificationConfig(BertForSequenceClassificationConfig):
|
735 |
+
@staticmethod
|
736 |
+
def gen_shared_head(self):
|
737 |
+
def func(hidden_states):
|
738 |
+
"""
|
739 |
+
Args:
|
740 |
+
hidden_states (Tensor): Hidden States of shape [B, L, hidden_units].
|
741 |
+
|
742 |
+
Returns:
|
743 |
+
logits (Tensor): Logits tensor of shape [B, C].
|
744 |
+
"""
|
745 |
+
logits = self.classifier(hidden_states)
|
746 |
+
return logits
|
747 |
+
return func
|
748 |
+
|
749 |
+
@staticmethod
|
750 |
+
def gen_forward(lambdas, loss_normalization=False):
|
751 |
+
def func(
|
752 |
+
self,
|
753 |
+
input_ids: Optional[torch.LongTensor] = None,
|
754 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
755 |
+
token_type_ids: Optional[torch.LongTensor] = None,
|
756 |
+
position_ids: Optional[torch.LongTensor] = None,
|
757 |
+
head_mask: Optional[torch.FloatTensor] = None,
|
758 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
759 |
+
labels: Optional[torch.LongTensor] = None,
|
760 |
+
output_attentions: Optional[bool] = None,
|
761 |
+
output_hidden_states: Optional[bool] = None,
|
762 |
+
return_dict: Optional[bool] = None,
|
763 |
+
) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
|
764 |
+
r"""
|
765 |
+
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
766 |
+
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
|
767 |
+
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
|
768 |
+
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
769 |
+
"""
|
770 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
771 |
+
|
772 |
+
outputs = self.roberta(
|
773 |
+
input_ids,
|
774 |
+
attention_mask=attention_mask,
|
775 |
+
token_type_ids=token_type_ids,
|
776 |
+
position_ids=position_ids,
|
777 |
+
head_mask=head_mask,
|
778 |
+
inputs_embeds=inputs_embeds,
|
779 |
+
output_attentions=output_attentions,
|
780 |
+
output_hidden_states=output_hidden_states,
|
781 |
+
return_dict=return_dict,
|
782 |
+
)
|
783 |
+
sequence_output = outputs[0]
|
784 |
+
logits = self.classifier(sequence_output)
|
785 |
+
|
786 |
+
loss = None
|
787 |
+
if labels is not None:
|
788 |
+
assert self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int) # TODO: remove this
|
789 |
+
# move labels to correct device to enable model parallelism
|
790 |
+
labels = labels.to(logits.device)
|
791 |
+
if self.config.problem_type is None:
|
792 |
+
if self.num_labels == 1:
|
793 |
+
self.config.problem_type = "regression"
|
794 |
+
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
|
795 |
+
self.config.problem_type = "single_label_classification"
|
796 |
+
else:
|
797 |
+
self.config.problem_type = "multi_label_classification"
|
798 |
+
|
799 |
+
if self.config.problem_type == "regression":
|
800 |
+
if self.num_labels == 1:
|
801 |
+
loss = F.mse_loss(logits.squeeze(), labels.squeeze())
|
802 |
+
else:
|
803 |
+
loss = F.mse_loss(logits, labels)
|
804 |
+
elif self.config.problem_type == "single_label_classification":
|
805 |
+
shared_head = BertForSequenceClassificationConfig.gen_shared_head(self)
|
806 |
+
criterion = BertForSequenceClassificationConfig.gen_criterion()
|
807 |
+
loss = trp_criterion(self.trp_blocks, shared_head, criterion, lambdas, sequence_output, logits, labels, loss_normalization)
|
808 |
+
elif self.config.problem_type == "multi_label_classification":
|
809 |
+
loss = F.binary_cross_entropy_with_logits(logits, labels)
|
810 |
+
|
811 |
+
if not return_dict:
|
812 |
+
output = (logits,) + outputs[2:]
|
813 |
+
return ((loss,) + output) if loss is not None else output
|
814 |
+
|
815 |
+
return SequenceClassifierOutput(
|
816 |
+
loss=loss,
|
817 |
+
logits=logits,
|
818 |
+
hidden_states=outputs.hidden_states,
|
819 |
+
attentions=outputs.attentions,
|
820 |
+
)
|
821 |
+
return func
|
822 |
+
|
823 |
+
|
824 |
+
# Wav2Vec2 for Speech Recognition
|
825 |
+
class Wav2Vec2ForCTCConfig(Config):
|
826 |
+
_HIDDEN_STATES_START_POSITION = 2
|
827 |
+
|
828 |
+
@staticmethod
|
829 |
+
def greedy_decode_ctc(
|
830 |
+
log_probs: torch.Tensor,
|
831 |
+
input_lengths: torch.Tensor,
|
832 |
+
blank_token_id: int,
|
833 |
+
target_lengths: torch.Tensor
|
834 |
+
):
|
835 |
+
"""
|
836 |
+
Convert logits to flattened predictions that match the shape of flattened_targets.
|
837 |
+
|
838 |
+
Args:
|
839 |
+
log_probs: [B, L, V] - log-softmax output
|
840 |
+
input_lengths: [B] - actual length of each input
|
841 |
+
blank_token_id: int - index of blank token
|
842 |
+
target_lengths: [B] - used to determine how many predictions to keep per sample
|
843 |
+
|
844 |
+
Returns:
|
845 |
+
flattened_predictions: 1D tensor, same total length as sum(target_lengths)
|
846 |
+
"""
|
847 |
+
batch_size = log_probs.size(0)
|
848 |
+
decoded_all = []
|
849 |
+
|
850 |
+
predicted_ids = log_probs.argmax(dim=-1) # [B, L]
|
851 |
+
|
852 |
+
for i in range(batch_size):
|
853 |
+
pred = predicted_ids[i][:input_lengths[i]] # [Li]
|
854 |
+
prev = None
|
855 |
+
decoded = []
|
856 |
+
for token in pred:
|
857 |
+
token = token.item()
|
858 |
+
if token != blank_token_id and token != prev:
|
859 |
+
decoded.append(token)
|
860 |
+
prev = token
|
861 |
+
# Trim or pad to match target_lengths[i]
|
862 |
+
tgt_len = target_lengths[i].item()
|
863 |
+
if len(decoded) >= tgt_len:
|
864 |
+
decoded = decoded[:tgt_len]
|
865 |
+
else:
|
866 |
+
decoded = decoded + [blank_token_id] * (tgt_len - len(decoded)) # pad with blank
|
867 |
+
decoded_all.extend(decoded)
|
868 |
+
|
869 |
+
return torch.tensor(decoded_all, dtype=torch.long, device=log_probs.device) # shape: [sum(target_lengths)]
|
870 |
+
|
871 |
+
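A small sanity check for `greedy_decode_ctc`: it collapses repeated tokens, drops blanks, and then trims or pads each sequence so the flattened result lines up with `flattened_targets`:

```python
import torch

log_probs = torch.log_softmax(torch.randn(2, 6, 5), dim=-1)  # [B, L, V], blank id = 0
input_lengths = torch.tensor([6, 4])
target_lengths = torch.tensor([3, 2])

flat_preds = Wav2Vec2ForCTCConfig.greedy_decode_ctc(
    log_probs, input_lengths, blank_token_id=0, target_lengths=target_lengths
)
print(flat_preds.shape)   # torch.Size([5]) == target_lengths.sum()
```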
@staticmethod
|
872 |
+
def gen_criterion(input_lengths: Tensor, pad_token_id: int, ctc_zero_infinity: bool):
|
873 |
+
def func(logits: Tensor, labels: Tensor, mask=None):
|
874 |
+
"""
|
875 |
+
Args:
|
876 |
+
logits (Tensor): Logits tensor of shape [B, L, V]; log-softmax is applied inside.
|
877 |
+
labels (Tensor): Target token ids of shape [B, L']; entries < 0 are treated as padding.
|
878 |
+
|
879 |
+
Returns:
|
880 |
+
loss (Tensor): Scalar tensor representing the loss.
|
881 |
+
mask (Tensor): Boolean mask tensor of shape [B].
|
882 |
+
"""
|
883 |
+
if mask is None:
|
884 |
+
mask = torch.ones_like(input_lengths, dtype=torch.float32, device=input_lengths.device)
|
885 |
+
|
886 |
+
log_probs = nn.functional.log_softmax(logits, dim=-1, dtype=torch.float32).transpose(0, 1)
|
887 |
+
labels_mask = labels >= 0
|
888 |
+
target_lengths = labels_mask.sum(-1)
|
889 |
+
flattened_targets = labels.masked_select(labels_mask)
|
890 |
+
with torch.backends.cudnn.flags(enabled=False):
|
891 |
+
masked_losses = nn.functional.ctc_loss(log_probs, flattened_targets, input_lengths, target_lengths, blank=pad_token_id, reduction="none", zero_infinity=ctc_zero_infinity)
|
892 |
+
loss = torch.sum(mask * masked_losses) / (torch.sum(mask) + 1e-6)
|
893 |
+
|
894 |
+
with torch.no_grad():
|
895 |
+
thres = 0.5
|
896 |
+
flattened_predictions = Wav2Vec2ForCTCConfig.greedy_decode_ctc(
|
897 |
+
log_probs.transpose(0, 1), # [B, T, V]
|
898 |
+
input_lengths=input_lengths,
|
899 |
+
blank_token_id=pad_token_id,
|
900 |
+
target_lengths=target_lengths
|
901 |
+
)
|
902 |
+
token_wise_mask = torch.eq(flattened_predictions, flattened_targets).to(flattened_targets.dtype)
|
903 |
+
segment_ids = torch.arange(len(target_lengths), device=target_lengths.device).repeat_interleave(target_lengths)
|
904 |
+
sequence_wise_mask = torch.zeros(len(target_lengths), dtype=target_lengths.dtype, device=token_wise_mask.device).scatter_add(0, segment_ids, token_wise_mask)
|
905 |
+
mask = mask * torch.ge(sequence_wise_mask, thres * target_lengths).to(flattened_targets.dtype)
|
906 |
+
|
907 |
+
return loss, mask
|
908 |
+
return func
|
909 |
+
|
910 |
+
@staticmethod
|
911 |
+
def gen_shared_head(self):
|
912 |
+
def func(hidden_states):
|
913 |
+
"""
|
914 |
+
Args:
|
915 |
+
hidden_states (Tensor): Hidden States of shape [B, L, hidden_units].
|
916 |
+
|
917 |
+
Returns:
|
918 |
+
logits (Tensor): Logits tensor of shape [B, L, V].
|
919 |
+
"""
|
920 |
+
logits = self.lm_head(hidden_states)
|
921 |
+
# log_probs = nn.functional.log_softmax(logits, dim=-1, dtype=torch.float32).transpose(0, 1)
|
922 |
+
return logits
|
923 |
+
return func
|
924 |
+
|
925 |
+
@staticmethod
|
926 |
+
def gen_forward(lambdas, loss_normalization=False):
|
927 |
+
def func(
|
928 |
+
self,
|
929 |
+
input_values: Optional[torch.Tensor],
|
930 |
+
attention_mask: Optional[torch.Tensor] = None,
|
931 |
+
output_attentions: Optional[bool] = None,
|
932 |
+
output_hidden_states: Optional[bool] = None,
|
933 |
+
return_dict: Optional[bool] = None,
|
934 |
+
labels: Optional[torch.Tensor] = None,
|
935 |
+
) -> Union[Tuple, CausalLMOutput]:
|
936 |
+
r"""
|
937 |
+
labels (`torch.LongTensor` of shape `(batch_size, target_length)`, *optional*):
|
938 |
+
Labels for connectionist temporal classification. Note that `target_length` has to be smaller or equal to
|
939 |
+
the sequence length of the output logits. Indices are selected in `[-100, 0, ..., config.vocab_size - 1]`.
|
940 |
+
All labels set to `-100` are ignored (masked), the loss is only computed for labels in `[0, ...,
|
941 |
+
config.vocab_size - 1]`.
|
942 |
+
"""
|
943 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
944 |
+
|
945 |
+
if labels is not None and labels.max() >= self.config.vocab_size:
|
946 |
+
raise ValueError(f"Label values must be <= vocab_size: {self.config.vocab_size}")
|
947 |
+
|
948 |
+
outputs = self.wav2vec2(
|
949 |
+
input_values,
|
950 |
+
attention_mask=attention_mask,
|
951 |
+
output_attentions=output_attentions,
|
952 |
+
output_hidden_states=output_hidden_states,
|
953 |
+
return_dict=return_dict,
|
954 |
+
)
|
955 |
+
|
956 |
+
hidden_states = outputs[0]
|
957 |
+
hidden_states = self.dropout(hidden_states)
|
958 |
+
|
959 |
+
logits = self.lm_head(hidden_states)
|
960 |
+
|
961 |
+
loss = None
|
962 |
+
if labels is not None:
|
963 |
+
# retrieve loss input_lengths from attention_mask
|
964 |
+
attention_mask = (
|
965 |
+
attention_mask if attention_mask is not None else torch.ones_like(input_values, dtype=torch.long)
|
966 |
+
)
|
967 |
+
input_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long)
|
968 |
+
shared_head = Wav2Vec2ForCTCConfig.gen_shared_head(self)
|
969 |
+
criterion = Wav2Vec2ForCTCConfig.gen_criterion(input_lengths, self.config.pad_token_id, self.config.ctc_zero_infinity)
|
970 |
+
loss = trp_criterion(self.trp_blocks, shared_head, criterion, lambdas, hidden_states, logits, labels, loss_normalization) # NOTE: Apply TRP!
|
971 |
+
|
972 |
+
if not return_dict:
|
973 |
+
output = (logits,) + outputs[Wav2Vec2ForCTCConfig._HIDDEN_STATES_START_POSITION:]
|
974 |
+
return ((loss,) + output) if loss is not None else output
|
975 |
+
|
976 |
+
return CausalLMOutput(
|
977 |
+
loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions
|
978 |
+
)
|
979 |
+
return func
|
980 |
+
|
981 |
+
|
982 |
+
# MBart for Translation
|
983 |
+
class MBartForConditionalGenerationConfig(Config):
|
984 |
+
@staticmethod
|
985 |
+
def gen_criterion(vocab_size: int, top_k=1):
|
986 |
+
def func(logits, labels, mask=None):
|
987 |
+
"""
|
988 |
+
Args:
|
989 |
+
logits (Tensor): Logits tensor of shape [B, L, V].
|
990 |
+
labels (Tensor): Target labels of shape [B, L].
|
991 |
+
|
992 |
+
Returns:
|
993 |
+
loss (Tensor): Scalar tensor representing the loss.
|
994 |
+
mask (Tensor): Boolean mask tensor of shape [B].
|
995 |
+
"""
|
996 |
+
if mask is None:
|
997 |
+
mask = torch.ones_like(labels.view(-1), dtype=torch.float32, device=labels.device)
|
998 |
+
|
999 |
+
masked_losses = F.cross_entropy(logits.view(-1, vocab_size), labels.view(-1), reduction="none")
|
1000 |
+
loss = torch.sum(mask * masked_losses) / (torch.sum(mask) + 1e-6)
|
1001 |
+
|
1002 |
+
with torch.no_grad():
|
1003 |
+
topk_values, topk_indices = torch.topk(logits.view(-1, vocab_size), top_k, dim=1)
|
1004 |
+
mask = mask * torch.eq(topk_indices, labels.view(-1, 1)).any(dim=1).to(logits.dtype)
|
1005 |
+
|
1006 |
+
return loss, mask
|
1007 |
+
return func
|
1008 |
+
|
1009 |
+
@staticmethod
|
1010 |
+
def gen_shared_head(self):
|
1011 |
+
def func(hidden_states):
|
1012 |
+
"""
|
1013 |
+
Args:
|
1014 |
+
hidden_states (Tensor): Hidden States of shape [B, L, hidden_units].
|
1015 |
+
|
1016 |
+
Returns:
|
1017 |
+
logits (Tensor): Logits tensor of shape [B, L, V].
|
1018 |
+
"""
|
1019 |
+
logits = self.lm_head(hidden_states) + self.final_logits_bias
|
1020 |
+
return logits
|
1021 |
+
return func
|
1022 |
+
|
1023 |
+
@staticmethod
|
1024 |
+
def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int):
|
1025 |
+
"""
|
1026 |
+
Shift input ids one token to the right, and wrap the last non-pad token (the <LID> token). Note that MBart does not
|
1027 |
+
have a single `decoder_start_token_id` in contrast to other Bart-like models.
|
1028 |
+
"""
|
1029 |
+
prev_output_tokens = input_ids.clone()
|
1030 |
+
|
1031 |
+
if pad_token_id is None:
|
1032 |
+
raise ValueError("self.model.config.pad_token_id has to be defined.")
|
1033 |
+
# replace possible -100 values in labels by `pad_token_id`
|
1034 |
+
prev_output_tokens.masked_fill_(prev_output_tokens == -100, pad_token_id)
|
1035 |
+
|
1036 |
+
index_of_eos = (prev_output_tokens.ne(pad_token_id).sum(dim=1) - 1).unsqueeze(-1)
|
1037 |
+
decoder_start_tokens = prev_output_tokens.gather(1, index_of_eos).squeeze()
|
1038 |
+
prev_output_tokens[:, 1:] = prev_output_tokens[:, :-1].clone()
|
1039 |
+
prev_output_tokens[:, 0] = decoder_start_tokens
|
1040 |
+
|
1041 |
+
return prev_output_tokens
|
1042 |
+
|
1043 |
+
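A worked example of `shift_tokens_right` with illustrative token ids (1 as the pad id and 250004 standing in for a language-id token): each row's final non-pad token is wrapped around to the front, which is how MBart builds its decoder inputs from the labels:

```python
import torch

pad_id = 1
labels = torch.tensor([
    [23, 57, 89, 250004, 1],
    [11, 250004, 1, 1, 1],
])
shifted = MBartForConditionalGenerationConfig.shift_tokens_right(labels, pad_id)
# tensor([[250004,     23,     57,     89, 250004],
#         [250004,     11, 250004,      1,      1]])
```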
@staticmethod
|
1044 |
+
def gen_forward(lambdas, loss_normalization=False):
|
1045 |
+
def func(
|
1046 |
+
self,
|
1047 |
+
input_ids: torch.LongTensor = None,
|
1048 |
+
attention_mask: Optional[torch.Tensor] = None,
|
1049 |
+
decoder_input_ids: Optional[torch.LongTensor] = None,
|
1050 |
+
decoder_attention_mask: Optional[torch.LongTensor] = None,
|
1051 |
+
head_mask: Optional[torch.Tensor] = None,
|
1052 |
+
decoder_head_mask: Optional[torch.Tensor] = None,
|
1053 |
+
cross_attn_head_mask: Optional[torch.Tensor] = None,
|
1054 |
+
encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
1055 |
+
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
1056 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
1057 |
+
decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
|
1058 |
+
labels: Optional[torch.LongTensor] = None,
|
1059 |
+
use_cache: Optional[bool] = None,
|
1060 |
+
output_attentions: Optional[bool] = None,
|
1061 |
+
output_hidden_states: Optional[bool] = None,
|
1062 |
+
return_dict: Optional[bool] = None,
|
1063 |
+
) -> Union[Seq2SeqLMOutput, Tuple[torch.FloatTensor]]:
|
1064 |
+
r"""
|
1065 |
+
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
1066 |
+
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
1067 |
+
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
1068 |
+
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
1069 |
+
|
1070 |
+
Returns:
|
1071 |
+
|
1072 |
+
"""
|
1073 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
1074 |
+
|
1075 |
+
if labels is not None:
|
1076 |
+
# if use_cache:
|
1077 |
+
# logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.")
|
1078 |
+
use_cache = False
|
1079 |
+
if decoder_input_ids is None and decoder_inputs_embeds is None:
|
1080 |
+
decoder_input_ids = MBartForConditionalGenerationConfig.shift_tokens_right(labels, self.config.pad_token_id)
|
1081 |
+
|
1082 |
+
outputs = self.model(
|
1083 |
+
input_ids,
|
1084 |
+
attention_mask=attention_mask,
|
1085 |
+
decoder_input_ids=decoder_input_ids,
|
1086 |
+
encoder_outputs=encoder_outputs,
|
1087 |
+
decoder_attention_mask=decoder_attention_mask,
|
1088 |
+
head_mask=head_mask,
|
1089 |
+
decoder_head_mask=decoder_head_mask,
|
1090 |
+
cross_attn_head_mask=cross_attn_head_mask,
|
1091 |
+
past_key_values=past_key_values,
|
1092 |
+
inputs_embeds=inputs_embeds,
|
1093 |
+
decoder_inputs_embeds=decoder_inputs_embeds,
|
1094 |
+
use_cache=use_cache,
|
1095 |
+
output_attentions=output_attentions,
|
1096 |
+
output_hidden_states=output_hidden_states,
|
1097 |
+
return_dict=return_dict,
|
1098 |
+
)
|
1099 |
+
lm_logits = self.lm_head(outputs[0]) + self.final_logits_bias
|
1100 |
+
|
1101 |
+
masked_lm_loss = None
|
1102 |
+
if labels is not None:
|
1103 |
+
shared_head = MBartForConditionalGenerationConfig.gen_shared_head(self)
|
1104 |
+
criterion = MBartForConditionalGenerationConfig.gen_criterion(self.config.vocab_size)
|
1105 |
+
masked_lm_loss = trp_criterion(self.trp_blocks, shared_head, criterion, lambdas, outputs[0], lm_logits, labels, loss_normalization)
|
1106 |
+
|
1107 |
+
if not return_dict:
|
1108 |
+
output = (lm_logits,) + outputs[1:]
|
1109 |
+
return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
|
1110 |
+
|
1111 |
+
return Seq2SeqLMOutput(
|
1112 |
+
loss=masked_lm_loss,
|
1113 |
+
logits=lm_logits,
|
1114 |
+
past_key_values=outputs.past_key_values,
|
1115 |
+
decoder_hidden_states=outputs.decoder_hidden_states,
|
1116 |
+
decoder_attentions=outputs.decoder_attentions,
|
1117 |
+
cross_attentions=outputs.cross_attentions,
|
1118 |
+
encoder_last_hidden_state=outputs.encoder_last_hidden_state,
|
1119 |
+
encoder_hidden_states=outputs.encoder_hidden_states,
|
1120 |
+
encoder_attentions=outputs.encoder_attentions,
|
1121 |
+
)
|
1122 |
+
return func
|
1123 |
+
|
1124 |
+
|
1125 |
+
def apply_trp(model, depths: int, p: float, lambdas: List[float], **kwargs):
|
1126 |
+
if isinstance(model, transformers.Wav2Vec2ForSequenceClassification):
|
1127 |
+
print("✅ Applying TRP to Wav2Vec2 for Audio Classification...")
|
1128 |
+
model.trp_blocks = torch.nn.ModuleList([TPBlock(depths, 768, p) for _ in lambdas])
|
1129 |
+
model.forward = types.MethodType(Wav2Vec2ForSequenceClassificationConfig.gen_forward(lambdas, False), model)
|
1130 |
+
elif isinstance(model, MobileNetV2):
|
1131 |
+
print("✅ Applying TRP to MobileNetV2 for Image Classification...")
|
1132 |
+
model.trp_blocks = torch.nn.ModuleList([TPBlock(depths, 1280, p) for _ in lambdas])
|
1133 |
+
model.forward = types.MethodType(MobileNetV2Config.gen_forward(lambdas, True, label_smoothing=kwargs["label_smoothing"], top_k=1), model)
|
1134 |
+
elif isinstance(model, ResNet):
|
1135 |
+
print("✅ Applying TRP to ResNet for Image Classification...")
|
1136 |
+
model.trp_blocks = torch.nn.ModuleList([TPBlock(depths, 2048, p) for _ in lambdas])
|
1137 |
+
model.forward = types.MethodType(ResNetConfig.gen_forward(lambdas, True, label_smoothing=kwargs["label_smoothing"], top_k=1), model)
|
1138 |
+
elif isinstance(model, EfficientNet):
|
1139 |
+
print("✅ Applying TRP to EfficientNet for Image Classification...")
|
1140 |
+
model.trp_blocks = torch.nn.ModuleList([TPBlock(depths, 1280, p) for _ in lambdas])
|
1141 |
+
model.forward = types.MethodType(EfficientNetConfig.gen_forward(lambdas, True, label_smoothing=kwargs["label_smoothing"], top_k=1), model)
|
1142 |
+
elif isinstance(model, VisionTransformer):
|
1143 |
+
print("✅ Applying TRP to VisionTransformer for Image Classification...")
|
1144 |
+
model.trp_blocks = torch.nn.ModuleList([TPBlock(depths, 768, p) for _ in lambdas])
|
1145 |
+
model.forward = types.MethodType(VisionTransformerConfig.gen_forward(lambdas, True, label_smoothing=kwargs["label_smoothing"], top_k=1), model)
|
1146 |
+
elif isinstance(model, transformers.BertForQuestionAnswering):
|
1147 |
+
print("✅ Applying TRP to Bert for Question Answering...")
|
1148 |
+
model.trp_blocks = torch.nn.ModuleList([TPBlock(depths, 768, p) for _ in lambdas])
|
1149 |
+
model.forward = types.MethodType(BertForQuestionAnsweringConfig.gen_forward(lambdas, True, 1), model)
|
1150 |
+
elif isinstance(model, FCN):
|
1151 |
+
print("✅ Applying TRP to FCN for Semantic Segmentation...")
|
1152 |
+
model.out_trp_blocks = torch.nn.ModuleList([TPBlock(depths, 2048, p, dim=1) for _ in lambdas])
|
1153 |
+
model.aux_trp_blocks = torch.nn.ModuleList([TPBlock(depths, 1024, p, dim=1) for _ in lambdas])
|
1154 |
+
model.forward = types.MethodType(FCNConfig.gen_forward(lambdas, True, 1), model)
|
1155 |
+
elif isinstance(model, DeepLabV3):
|
1156 |
+
print("✅ Applying TRP to DeepLabV3 for Semantic Segmentation...")
|
1157 |
+
model.out_trp_blocks = torch.nn.ModuleList([TPBlock(depths, 2048, p, dim=1) for _ in lambdas])
|
1158 |
+
model.aux_trp_blocks = torch.nn.ModuleList([TPBlock(depths, 1024, p, dim=1) for _ in lambdas])
|
1159 |
+
model.forward = types.MethodType(DeepLabV3Config.gen_forward(lambdas, True, 1), model)
|
1160 |
+
elif isinstance(model, transformers.BertForSequenceClassification):
|
1161 |
+
print("✅ Applying TRP to Bert for Text Classification...")
|
1162 |
+
model.trp_blocks = torch.nn.ModuleList([TPBlock(depths, 768, p) for _ in lambdas])
|
1163 |
+
model.forward = types.MethodType(BertForSequenceClassificationConfig.gen_forward(lambdas, False), model)
|
1164 |
+
elif isinstance(model, transformers.RobertaForSequenceClassification):
|
1165 |
+
print("✅ Applying TRP to Roberta for Text Classification...")
|
1166 |
+
model.trp_blocks = torch.nn.ModuleList([TPBlock(depths, 768, p) for _ in lambdas])
|
1167 |
+
model.forward = types.MethodType(RobertaForSequenceClassificationConfig.gen_forward(lambdas, False), model)
|
1168 |
+
elif isinstance(model, transformers.Wav2Vec2ForCTC):
|
1169 |
+
print("✅ Applying TRP to Wav2Vec2 for Speech Recognition...")
|
1170 |
+
model.trp_blocks = torch.nn.ModuleList([TPBlock(depths, 1024, p) for _ in lambdas])
|
1171 |
+
model.forward = types.MethodType(Wav2Vec2ForCTCConfig.gen_forward(lambdas, False), model)
|
1172 |
+
elif isinstance(model, transformers.MBartForConditionalGeneration):
|
1173 |
+
print("✅ Applying TRP to MBart for Translation...")
|
1174 |
+
model.trp_blocks = torch.nn.ModuleList([TPBlock(depths, 1024, p) for _ in lambdas])
|
1175 |
+
model.forward = types.MethodType(MBartForConditionalGenerationConfig.gen_forward(lambdas, False), model)
|
1176 |
+
else:
|
1177 |
+
torch._assert(
|
1178 |
+
isinstance(model, transformers.Wav2Vec2ForSequenceClassification),
|
1179 |
+
"The model should be an object of [`Wav2Vec2ForSequenceClassification`].")
|
1180 |
+
|
1181 |
+
return model
|
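A hedged end-to-end sketch of how `apply_trp` is intended to be called (the hyperparameter values are illustrative, and the import assumes this file is importable as `trplib`):

```python
import torch
import torchvision
from trplib import apply_trp

model = torchvision.models.resnet50(weights=None)   # or a pretrained weights enum
model = apply_trp(model, depths=2, p=0.1, lambdas=[0.4, 0.2, 0.1],
                  label_smoothing=0.0)               # vision models require label_smoothing

model.train()
images = torch.randn(2, 3, 224, 224)
targets = torch.randint(0, 1000, (2,))
logits, loss = model(images, targets)    # patched forward: (logits, loss) in training

model.eval()
with torch.no_grad():
    logits = model(images)               # eval mode: logits only
```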
hpo-examples/image-classification/utils.py
ADDED
@@ -0,0 +1,465 @@
1 |
+
import copy
|
2 |
+
import datetime
|
3 |
+
import errno
|
4 |
+
import hashlib
|
5 |
+
import os
|
6 |
+
import time
|
7 |
+
from collections import defaultdict, deque, OrderedDict
|
8 |
+
from typing import List, Optional, Tuple
|
9 |
+
|
10 |
+
import torch
|
11 |
+
import torch.distributed as dist
|
12 |
+
|
13 |
+
|
14 |
+
class SmoothedValue:
|
15 |
+
"""Track a series of values and provide access to smoothed values over a
|
16 |
+
window or the global series average.
|
17 |
+
"""
|
18 |
+
|
19 |
+
def __init__(self, window_size=20, fmt=None):
|
20 |
+
if fmt is None:
|
21 |
+
fmt = "{median:.4f} ({global_avg:.4f})"
|
22 |
+
self.deque = deque(maxlen=window_size)
|
23 |
+
self.total = 0.0
|
24 |
+
self.count = 0
|
25 |
+
self.fmt = fmt
|
26 |
+
|
27 |
+
def update(self, value, n=1):
|
28 |
+
self.deque.append(value)
|
29 |
+
self.count += n
|
30 |
+
self.total += value * n
|
31 |
+
|
32 |
+
def synchronize_between_processes(self):
|
33 |
+
"""
|
34 |
+
Warning: does not synchronize the deque!
|
35 |
+
"""
|
36 |
+
t = reduce_across_processes([self.count, self.total])
|
37 |
+
t = t.tolist()
|
38 |
+
self.count = int(t[0])
|
39 |
+
self.total = t[1]
|
40 |
+
|
41 |
+
@property
|
42 |
+
def median(self):
|
43 |
+
d = torch.tensor(list(self.deque))
|
44 |
+
return d.median().item()
|
45 |
+
|
46 |
+
@property
|
47 |
+
def avg(self):
|
48 |
+
d = torch.tensor(list(self.deque), dtype=torch.float32)
|
49 |
+
return d.mean().item()
|
50 |
+
|
51 |
+
@property
|
52 |
+
def global_avg(self):
|
53 |
+
return self.total / self.count
|
54 |
+
|
55 |
+
@property
|
56 |
+
def max(self):
|
57 |
+
return max(self.deque)
|
58 |
+
|
59 |
+
@property
|
60 |
+
def value(self):
|
61 |
+
return self.deque[-1]
|
62 |
+
|
63 |
+
def __str__(self):
|
64 |
+
return self.fmt.format(
|
65 |
+
median=self.median, avg=self.avg, global_avg=self.global_avg, max=self.max, value=self.value
|
66 |
+
)
|
67 |
+
|
68 |
+
|
69 |
+
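A brief usage sketch for `SmoothedValue` (the values are synthetic):

```python
import random

meter = SmoothedValue(window_size=20, fmt="{median:.4f} ({global_avg:.4f})")
for _ in range(100):
    meter.update(random.random(), n=32)   # one value per batch, n = batch size

print(str(meter))        # median over the last 20 updates (running global average)
print(meter.global_avg)  # total / count over the whole run
```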
class MetricLogger:
|
70 |
+
def __init__(self, delimiter="\t"):
|
71 |
+
self.meters = defaultdict(SmoothedValue)
|
72 |
+
self.delimiter = delimiter
|
73 |
+
|
74 |
+
def update(self, **kwargs):
|
75 |
+
for k, v in kwargs.items():
|
76 |
+
if isinstance(v, torch.Tensor):
|
77 |
+
v = v.item()
|
78 |
+
assert isinstance(v, (float, int))
|
79 |
+
self.meters[k].update(v)
|
80 |
+
|
81 |
+
def __getattr__(self, attr):
|
82 |
+
if attr in self.meters:
|
83 |
+
return self.meters[attr]
|
84 |
+
if attr in self.__dict__:
|
85 |
+
return self.__dict__[attr]
|
86 |
+
raise AttributeError(f"'{type(self).__name__}' object has no attribute '{attr}'")
|
87 |
+
|
88 |
+
def __str__(self):
|
89 |
+
loss_str = []
|
90 |
+
for name, meter in self.meters.items():
|
91 |
+
loss_str.append(f"{name}: {str(meter)}")
|
92 |
+
return self.delimiter.join(loss_str)
|
93 |
+
|
94 |
+
def synchronize_between_processes(self):
|
95 |
+
for meter in self.meters.values():
|
96 |
+
meter.synchronize_between_processes()
|
97 |
+
|
98 |
+
def add_meter(self, name, meter):
|
99 |
+
self.meters[name] = meter
|
100 |
+
|
101 |
+
def log_every(self, iterable, print_freq, header=None):
|
102 |
+
i = 0
|
103 |
+
if not header:
|
104 |
+
header = ""
|
105 |
+
start_time = time.time()
|
106 |
+
end = time.time()
|
107 |
+
iter_time = SmoothedValue(fmt="{avg:.4f}")
|
108 |
+
data_time = SmoothedValue(fmt="{avg:.4f}")
|
109 |
+
space_fmt = ":" + str(len(str(len(iterable)))) + "d"
|
110 |
+
if torch.cuda.is_available():
|
111 |
+
log_msg = self.delimiter.join(
|
112 |
+
[
|
113 |
+
header,
|
114 |
+
"[{0" + space_fmt + "}/{1}]",
|
115 |
+
"eta: {eta}",
|
116 |
+
"{meters}",
|
117 |
+
"time: {time}",
|
118 |
+
"data: {data}",
|
119 |
+
"max mem: {memory:.0f}",
|
120 |
+
]
|
121 |
+
)
|
122 |
+
else:
|
123 |
+
log_msg = self.delimiter.join(
|
124 |
+
[header, "[{0" + space_fmt + "}/{1}]", "eta: {eta}", "{meters}", "time: {time}", "data: {data}"]
|
125 |
+
)
|
126 |
+
MB = 1024.0 * 1024.0
|
127 |
+
for obj in iterable:
|
128 |
+
data_time.update(time.time() - end)
|
129 |
+
yield obj
|
130 |
+
iter_time.update(time.time() - end)
|
131 |
+
if i % print_freq == 0:
|
132 |
+
eta_seconds = iter_time.global_avg * (len(iterable) - i)
|
133 |
+
eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
|
134 |
+
if torch.cuda.is_available():
|
135 |
+
print(
|
136 |
+
log_msg.format(
|
137 |
+
i,
|
138 |
+
len(iterable),
|
139 |
+
eta=eta_string,
|
140 |
+
meters=str(self),
|
141 |
+
time=str(iter_time),
|
142 |
+
data=str(data_time),
|
143 |
+
memory=torch.cuda.max_memory_allocated() / MB,
|
144 |
+
)
|
145 |
+
)
|
146 |
+
else:
|
147 |
+
print(
|
148 |
+
log_msg.format(
|
149 |
+
i, len(iterable), eta=eta_string, meters=str(self), time=str(iter_time), data=str(data_time)
|
150 |
+
)
|
151 |
+
)
|
152 |
+
i += 1
|
153 |
+
end = time.time()
|
154 |
+
total_time = time.time() - start_time
|
155 |
+
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
|
156 |
+
print(f"{header} Total time: {total_time_str}")
|
157 |
+
|
158 |
+
|
159 |
+
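And a sketch of how `MetricLogger.log_every` is typically wrapped around a data loader (the loop body and values here are placeholders):

```python
metric_logger = MetricLogger(delimiter="  ")
metric_logger.add_meter("lr", SmoothedValue(window_size=1, fmt="{value}"))

data_loader = range(100)   # stand-in for a real DataLoader
for batch in metric_logger.log_every(data_loader, print_freq=10, header="Epoch: [0]"):
    loss = 0.5             # compute the real loss / take an optimizer step here
    metric_logger.update(loss=loss, lr=0.01)
```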
class ExponentialMovingAverage(torch.optim.swa_utils.AveragedModel):
|
160 |
+
"""Maintains moving averages of model parameters using an exponential decay.
|
161 |
+
``ema_avg = decay * avg_model_param + (1 - decay) * model_param``
|
162 |
+
`torch.optim.swa_utils.AveragedModel <https://pytorch.org/docs/stable/optim.html#custom-averaging-strategies>`_
|
163 |
+
is used to compute the EMA.
|
164 |
+
"""
|
165 |
+
|
166 |
+
def __init__(self, model, decay, device="cpu"):
|
167 |
+
def ema_avg(avg_model_param, model_param, num_averaged):
|
168 |
+
return decay * avg_model_param + (1 - decay) * model_param
|
169 |
+
|
170 |
+
super().__init__(model, device, ema_avg, use_buffers=True)
|
171 |
+
|
172 |
+
|
173 |
+
def accuracy(output, target, topk=(1,)):
|
174 |
+
"""Computes the accuracy over the k top predictions for the specified values of k"""
|
175 |
+
with torch.inference_mode():
|
176 |
+
maxk = max(topk)
|
177 |
+
batch_size = target.size(0)
|
178 |
+
if target.ndim == 2:
|
179 |
+
target = target.max(dim=1)[1]
|
180 |
+
|
181 |
+
_, pred = output.topk(maxk, 1, True, True)
|
182 |
+
pred = pred.t()
|
183 |
+
correct = pred.eq(target[None])
|
184 |
+
|
185 |
+
res = []
|
186 |
+
for k in topk:
|
187 |
+
correct_k = correct[:k].flatten().sum(dtype=torch.float32)
|
188 |
+
res.append(correct_k * (100.0 / batch_size))
|
189 |
+
return res
|
190 |
+
|
191 |
+
|
192 |
+
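For instance, top-1 and top-5 accuracy for a batch of ImageNet-sized logits:

```python
import torch

logits = torch.randn(32, 1000)
targets = torch.randint(0, 1000, (32,))
acc1, acc5 = accuracy(logits, targets, topk=(1, 5))   # percentages in [0, 100]
```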
def mkdir(path):
|
193 |
+
try:
|
194 |
+
os.makedirs(path)
|
195 |
+
except OSError as e:
|
196 |
+
if e.errno != errno.EEXIST:
|
197 |
+
raise
|
198 |
+
|
199 |
+
|
200 |
+
def setup_for_distributed(is_master):
|
201 |
+
"""
|
202 |
+
This function disables printing when not in master process
|
203 |
+
"""
|
204 |
+
import builtins as __builtin__
|
205 |
+
|
206 |
+
builtin_print = __builtin__.print
|
207 |
+
|
208 |
+
def print(*args, **kwargs):
|
209 |
+
force = kwargs.pop("force", False)
|
210 |
+
if is_master or force:
|
211 |
+
builtin_print(*args, **kwargs)
|
212 |
+
|
213 |
+
__builtin__.print = print
|
214 |
+
|
215 |
+
|
216 |
+
def is_dist_avail_and_initialized():
|
217 |
+
if not dist.is_available():
|
218 |
+
return False
|
219 |
+
if not dist.is_initialized():
|
220 |
+
return False
|
221 |
+
return True
|
222 |
+
|
223 |
+
|
224 |
+
def get_world_size():
|
225 |
+
if not is_dist_avail_and_initialized():
|
226 |
+
return 1
|
227 |
+
return dist.get_world_size()
|
228 |
+
|
229 |
+
|
230 |
+
def get_rank():
|
231 |
+
if not is_dist_avail_and_initialized():
|
232 |
+
return 0
|
233 |
+
return dist.get_rank()
|
234 |
+
|
235 |
+
|
236 |
+
def is_main_process():
|
237 |
+
return get_rank() == 0
|
238 |
+
|
239 |
+
|
240 |
+
def save_on_master(*args, **kwargs):
|
241 |
+
if is_main_process():
|
242 |
+
torch.save(*args, **kwargs)
|
243 |
+
|
244 |
+
|
245 |
+
def init_distributed_mode(args):
|
246 |
+
if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
|
247 |
+
args.rank = int(os.environ["RANK"])
|
248 |
+
args.world_size = int(os.environ["WORLD_SIZE"])
|
249 |
+
args.gpu = int(os.environ["LOCAL_RANK"])
|
250 |
+
elif "SLURM_PROCID" in os.environ:
|
251 |
+
args.rank = int(os.environ["SLURM_PROCID"])
|
252 |
+
args.gpu = args.rank % torch.cuda.device_count()
|
253 |
+
elif hasattr(args, "rank"):
|
254 |
+
pass
|
255 |
+
else:
|
256 |
+
print("Not using distributed mode")
|
257 |
+
args.distributed = False
|
258 |
+
return
|
259 |
+
|
260 |
+
args.distributed = True
|
261 |
+
|
262 |
+
torch.cuda.set_device(args.gpu)
|
263 |
+
args.dist_backend = "nccl"
|
264 |
+
print(f"| distributed init (rank {args.rank}): {args.dist_url}", flush=True)
|
265 |
+
torch.distributed.init_process_group(
|
266 |
+
backend=args.dist_backend, init_method=args.dist_url, world_size=args.world_size, rank=args.rank
|
267 |
+
)
|
268 |
+
torch.distributed.barrier()
|
269 |
+
setup_for_distributed(args.rank == 0)
|
270 |
+
|
271 |
+
|
272 |
+
def average_checkpoints(inputs):
|
273 |
+
"""Loads checkpoints from inputs and returns a model with averaged weights. Original implementation taken from:
|
274 |
+
https://github.com/pytorch/fairseq/blob/a48f235636557b8d3bc4922a6fa90f3a0fa57955/scripts/average_checkpoints.py#L16
|
275 |
+
|
276 |
+
Args:
|
277 |
+
inputs (List[str]): An iterable of string paths of checkpoints to load from.
|
278 |
+
Returns:
|
279 |
+
A dict of string keys mapping to various values. The 'model' key
|
280 |
+
from the returned dict should correspond to an OrderedDict mapping
|
281 |
+
string parameter names to torch Tensors.
|
282 |
+
"""
|
283 |
+
params_dict = OrderedDict()
|
284 |
+
params_keys = None
|
285 |
+
new_state = None
|
286 |
+
num_models = len(inputs)
|
287 |
+
for fpath in inputs:
|
288 |
+
with open(fpath, "rb") as f:
|
289 |
+
state = torch.load(
|
290 |
+
f,
|
291 |
+
map_location=(lambda s, _: torch.serialization.default_restore_location(s, "cpu")),
|
292 |
+
)
|
293 |
+
# Copies over the settings from the first checkpoint
|
294 |
+
if new_state is None:
|
295 |
+
new_state = state
|
296 |
+
model_params = state["model"]
|
297 |
+
model_params_keys = list(model_params.keys())
|
298 |
+
if params_keys is None:
|
299 |
+
params_keys = model_params_keys
|
300 |
+
elif params_keys != model_params_keys:
|
301 |
+
raise KeyError(
|
302 |
+
f"For checkpoint {f}, expected list of params: {params_keys}, but found: {model_params_keys}"
|
303 |
+
)
|
304 |
+
for k in params_keys:
|
305 |
+
p = model_params[k]
|
306 |
+
if isinstance(p, torch.HalfTensor):
|
307 |
+
p = p.float()
|
308 |
+
if k not in params_dict:
|
309 |
+
params_dict[k] = p.clone()
|
310 |
+
# NOTE: clone() is needed in case of p is a shared parameter
|
311 |
+
else:
|
312 |
+
params_dict[k] += p
|
313 |
+
averaged_params = OrderedDict()
|
314 |
+
for k, v in params_dict.items():
|
315 |
+
averaged_params[k] = v
|
316 |
+
if averaged_params[k].is_floating_point():
|
317 |
+
averaged_params[k].div_(num_models)
|
318 |
+
else:
|
319 |
+
averaged_params[k] //= num_models
|
320 |
+
new_state["model"] = averaged_params
|
321 |
+
return new_state
|
322 |
+
|
323 |
+
|
324 |
+
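A short usage sketch (the checkpoint paths are illustrative; each file is expected to hold a dict with a `'model'` state dict, as produced by the training script):

```python
import torch

avg_state = average_checkpoints([
    "output/model_33.pth",
    "output/model_34.pth",
    "output/model_35.pth",
])
torch.save(avg_state, "output/model_averaged.pth")
```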
def store_model_weights(model, checkpoint_path, checkpoint_key="model", strict=True):
|
325 |
+
"""
|
326 |
+
This method can be used to prepare weights files for new models. It receives as
|
327 |
+
input a model architecture and a checkpoint from the training script and produces
|
328 |
+
a file with the weights ready for release.
|
329 |
+
|
330 |
+
Examples:
|
331 |
+
from torchvision import models as M
|
332 |
+
|
333 |
+
# Classification
|
334 |
+
model = M.mobilenet_v3_large(weights=None)
|
335 |
+
        print(store_model_weights(model, './class.pth'))

        # Quantized Classification
        model = M.quantization.mobilenet_v3_large(weights=None, quantize=False)
        model.fuse_model(is_qat=True)
        model.qconfig = torch.ao.quantization.get_default_qat_qconfig('qnnpack')
        _ = torch.ao.quantization.prepare_qat(model, inplace=True)
        print(store_model_weights(model, './qat.pth'))

        # Object Detection
        model = M.detection.fasterrcnn_mobilenet_v3_large_fpn(weights=None, weights_backbone=None)
        print(store_model_weights(model, './obj.pth'))

        # Segmentation
        model = M.segmentation.deeplabv3_mobilenet_v3_large(weights=None, weights_backbone=None, aux_loss=True)
        print(store_model_weights(model, './segm.pth', strict=False))

    Args:
        model (pytorch.nn.Module): The model on which the weights will be loaded for validation purposes.
        checkpoint_path (str): The path of the checkpoint we will load.
        checkpoint_key (str, optional): The key of the checkpoint where the model weights are stored.
            Default: "model".
        strict (bool): whether to strictly enforce that the keys
            in :attr:`state_dict` match the keys returned by this module's
            :meth:`~torch.nn.Module.state_dict` function. Default: ``True``

    Returns:
        output_path (str): The location where the weights are saved.
    """
    # Store the new model next to the checkpoint_path
    checkpoint_path = os.path.abspath(checkpoint_path)
    output_dir = os.path.dirname(checkpoint_path)

    # Deep copy to avoid side-effects on the model object.
    model = copy.deepcopy(model)
    checkpoint = torch.load(checkpoint_path, map_location="cpu")

    # Load the weights to the model to validate that everything works
    # and remove unnecessary weights (such as auxiliaries, etc)
    if checkpoint_key == "model_ema":
        del checkpoint[checkpoint_key]["n_averaged"]
        torch.nn.modules.utils.consume_prefix_in_state_dict_if_present(checkpoint[checkpoint_key], "module.")
    model.load_state_dict(checkpoint[checkpoint_key], strict=strict)

    tmp_path = os.path.join(output_dir, str(model.__hash__()))
    torch.save(model.state_dict(), tmp_path)

    sha256_hash = hashlib.sha256()
    with open(tmp_path, "rb") as f:
        # Read and update hash string value in blocks of 4K
        for byte_block in iter(lambda: f.read(4096), b""):
            sha256_hash.update(byte_block)
        hh = sha256_hash.hexdigest()

    output_path = os.path.join(output_dir, "weights-" + str(hh[:8]) + ".pth")
    os.replace(tmp_path, output_path)

    return output_path


def reduce_across_processes(val):
    if not is_dist_avail_and_initialized():
        # nothing to sync, but we still convert to tensor for consistency with the distributed case.
        return torch.tensor(val)

    t = torch.tensor(val, device="cuda")
    dist.barrier()
    dist.all_reduce(t)
    return t


def set_weight_decay(
    model: torch.nn.Module,
    weight_decay: float,
    norm_weight_decay: Optional[float] = None,
    norm_classes: Optional[List[type]] = None,
    custom_keys_weight_decay: Optional[List[Tuple[str, float]]] = None,
):
    if not norm_classes:
        norm_classes = [
            torch.nn.modules.batchnorm._BatchNorm,
            torch.nn.LayerNorm,
            torch.nn.GroupNorm,
            torch.nn.modules.instancenorm._InstanceNorm,
            torch.nn.LocalResponseNorm,
        ]
    norm_classes = tuple(norm_classes)

    params = {
        "other": [],
        "norm": [],
    }
    params_weight_decay = {
        "other": weight_decay,
        "norm": norm_weight_decay,
    }
    custom_keys = []
    if custom_keys_weight_decay is not None:
        for key, weight_decay in custom_keys_weight_decay:
            params[key] = []
            params_weight_decay[key] = weight_decay
            custom_keys.append(key)

    def _add_params(module, prefix=""):
        for name, p in module.named_parameters(recurse=False):
            if not p.requires_grad:
                continue
            is_custom_key = False
            for key in custom_keys:
                target_name = f"{prefix}.{name}" if prefix != "" and "." in key else name
                if key == target_name:
                    params[key].append(p)
                    is_custom_key = True
                    break
            if not is_custom_key:
                if norm_weight_decay is not None and isinstance(module, norm_classes):
                    params["norm"].append(p)
                else:
                    params["other"].append(p)

        for child_name, child_module in module.named_children():
            child_prefix = f"{prefix}.{child_name}" if prefix != "" else child_name
            _add_params(child_module, prefix=child_prefix)

    _add_params(model)

    param_groups = []
    for key in params:
        if len(params[key]) > 0:
            param_groups.append({"params": params[key], "weight_decay": params_weight_decay[key]})
    return param_groups
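For context, `set_weight_decay` is typically consumed when constructing the optimizer: it splits parameters into "other", "norm", and any custom-key groups so that normalization layers (and selected parameters) can receive a different weight decay than the rest of the model. A minimal usage sketch, assuming this `utils.py` is importable as `utils`; the model choice, hyperparameter values, and the `fc.bias` custom key are illustrative:

```python
import torch
import torchvision.models as M

import utils  # this reference utils.py, assumed to be on the import path

model = M.resnet50(weights=None)

# Illustrative settings: regular weight decay everywhere except norm layers and
# the classifier bias, which get 0.0 via norm_weight_decay / custom_keys_weight_decay.
param_groups = utils.set_weight_decay(
    model,
    weight_decay=1e-4,
    norm_weight_decay=0.0,
    custom_keys_weight_decay=[("fc.bias", 0.0)],
)

# Each returned group carries its own "weight_decay" entry, which overrides
# the optimizer-level default on a per-group basis.
optimizer = torch.optim.SGD(param_groups, lr=0.1, momentum=0.9, weight_decay=1e-4)
```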
hpo-examples/image-classification/vit_b_16/model_4.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:29f1bd991a3f27f7982e8700f31332cb94c4783b83e240fc71f1eca03d4eb468
size 1053172110
hpo-examples/question-answering/qa/README.md
ADDED
@@ -0,0 +1,55 @@
---
library_name: transformers
license: apache-2.0
base_model: google-bert/bert-base-uncased
tags:
- generated_from_trainer
datasets:
- squad
model-index:
- name: baseline
  results: []
---

<!-- This model card has been generated automatically according to the information the Trainer had access to. You
should probably proofread and complete it, then remove this comment. -->

# baseline

This model is a fine-tuned version of [google-bert/bert-base-uncased](https://huggingface.co/google-bert/bert-base-uncased) on the squad dataset.

## Model description

More information needed

## Intended uses & limitations

More information needed

## Training and evaluation data

More information needed

## Training procedure

### Training hyperparameters

The following hyperparameters were used during training:
- learning_rate: 3e-05
- train_batch_size: 12
- eval_batch_size: 8
- seed: 42
- optimizer: Use adamw_torch with betas=(0.9,0.999) and epsilon=1e-08 and optimizer_args=No additional optimizer arguments
- lr_scheduler_type: linear
- num_epochs: 2.0

### Training results


### Framework versions

- Transformers 4.49.0
- Pytorch 2.6.0+cu118
- Datasets 3.3.1
- Tokenizers 0.21.0
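For readers who want to reproduce this run programmatically rather than through the bundled run script, the hyperparameters in the card map directly onto `transformers.TrainingArguments`. A minimal sketch; the `output_dir` name is illustrative, and the card's train/eval batch sizes are assumed to correspond to the per-device arguments:

```python
from transformers import TrainingArguments

# Values taken from the "Training hyperparameters" section above; output_dir is illustrative.
training_args = TrainingArguments(
    output_dir="qa",
    learning_rate=3e-5,
    per_device_train_batch_size=12,
    per_device_eval_batch_size=8,
    seed=42,
    optim="adamw_torch",
    adam_beta1=0.9,
    adam_beta2=0.999,
    adam_epsilon=1e-8,
    lr_scheduler_type="linear",
    num_train_epochs=2.0,
)
```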
hpo-examples/question-answering/qa/all_results.json
ADDED
@@ -0,0 +1,15 @@
{
  "epoch": 2.0,
  "eval_exact_match": 81.49479659413434,
  "eval_f1": 88.62945564424126,
  "eval_runtime": 61.0301,
  "eval_samples": 10784,
  "eval_samples_per_second": 176.7,
  "eval_steps_per_second": 22.087,
  "total_flos": 3.541929151120589e+16,
  "train_loss": 1.148573803161563,
  "train_runtime": 3245.3985,
  "train_samples": 88524,
  "train_samples_per_second": 54.554,
  "train_steps_per_second": 4.546
}
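The `eval_exact_match` and `eval_f1` fields above are the standard SQuAD v1.1 metrics. They can be recomputed from a prediction set with the `evaluate` package listed in this example's requirements; a small sketch with a made-up example id and answer span:

```python
import evaluate

squad_metric = evaluate.load("squad")

# Single made-up example; in practice predictions/references cover the whole eval split.
predictions = [{"id": "example-0", "prediction_text": "Denver Broncos"}]
references = [{"id": "example-0", "answers": {"text": ["Denver Broncos"], "answer_start": [0]}}]

print(squad_metric.compute(predictions=predictions, references=references))
# {'exact_match': 100.0, 'f1': 100.0}
```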
hpo-examples/question-answering/qa/config.json
ADDED
@@ -0,0 +1,26 @@
{
  "_name_or_path": "google-bert/bert-base-uncased",
  "architectures": [
    "BertForQuestionAnswering"
  ],
  "attention_probs_dropout_prob": 0.1,
  "classifier_dropout": null,
  "gradient_checkpointing": false,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 0,
  "position_embedding_type": "absolute",
  "torch_dtype": "float32",
  "transformers_version": "4.49.0",
  "type_vocab_size": 2,
  "use_cache": true,
  "vocab_size": 30522
}
hpo-examples/question-answering/qa/eval_nbest_predictions.json
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:8b8d44953cbe0ce20d1d1b62b72e7adba18bf1dc81d055492e22bfa21ff46657
size 49596120
hpo-examples/question-answering/qa/eval_predictions.json
ADDED
The diff for this file is too large to render.
See raw diff
hpo-examples/question-answering/qa/eval_results.json
ADDED
@@ -0,0 +1,9 @@
{
  "epoch": 2.0,
  "eval_exact_match": 81.49479659413434,
  "eval_f1": 88.62945564424126,
  "eval_runtime": 61.0301,
  "eval_samples": 10784,
  "eval_samples_per_second": 176.7,
  "eval_steps_per_second": 22.087
}
hpo-examples/question-answering/qa/model.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:38003bd65e4bfa70dd16886f29af7ab00d1aa0ae4de191b0a7de4d7883d17dde
size 442683784
hpo-examples/question-answering/qa/runs/May15_03-24-14_cs-Precision-7960-Tower/events.out.tfevents.1747293859.cs-Precision-7960-Tower.147971.0
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:36bfca6273a2422943de7b634cf75efd69b8e92079abe84df9e9c9e026d497f6
size 11535
hpo-examples/question-answering/qa/runs/May15_03-24-14_cs-Precision-7960-Tower/events.out.tfevents.1747297197.cs-Precision-7960-Tower.147971.1
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:259c79a03ba9c522b1fd728e92dae5cfc31c6cd73b2377d124749c83a0163910
size 412
hpo-examples/question-answering/qa/special_tokens_map.json
ADDED
@@ -0,0 +1,7 @@
{
  "cls_token": "[CLS]",
  "mask_token": "[MASK]",
  "pad_token": "[PAD]",
  "sep_token": "[SEP]",
  "unk_token": "[UNK]"
}
hpo-examples/question-answering/qa/tokenizer.json
ADDED
The diff for this file is too large to render.
See raw diff
hpo-examples/question-answering/qa/tokenizer_config.json
ADDED
@@ -0,0 +1,56 @@
{
  "added_tokens_decoder": {
    "0": {
      "content": "[PAD]",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "100": {
      "content": "[UNK]",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "101": {
      "content": "[CLS]",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "102": {
      "content": "[SEP]",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "103": {
      "content": "[MASK]",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    }
  },
  "clean_up_tokenization_spaces": false,
  "cls_token": "[CLS]",
  "do_lower_case": true,
  "extra_special_tokens": {},
  "mask_token": "[MASK]",
  "model_max_length": 512,
  "pad_token": "[PAD]",
  "sep_token": "[SEP]",
  "strip_accents": null,
  "tokenize_chinese_chars": true,
  "tokenizer_class": "BertTokenizer",
  "unk_token": "[UNK]"
}
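Together, `config.json`, `model.safetensors`, and the tokenizer files above form a self-contained checkpoint. A minimal inference sketch, assuming the files are checked out locally under the repository path shown in the file headers; the question/context pair is illustrative:

```python
from transformers import pipeline

# Point the pipeline at the local checkpoint directory from this upload.
qa = pipeline("question-answering", model="hpo-examples/question-answering/qa")

result = qa(
    question="Which dataset was the baseline fine-tuned on?",
    context="The baseline model is a BERT base model fine-tuned on the SQuAD dataset for two epochs.",
)
print(result["answer"], round(result["score"], 3))
```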
hpo-examples/question-answering/qa/train_results.json
ADDED
@@ -0,0 +1,9 @@
{
  "epoch": 2.0,
  "total_flos": 3.541929151120589e+16,
  "train_loss": 1.148573803161563,
  "train_runtime": 3245.3985,
  "train_samples": 88524,
  "train_samples_per_second": 54.554,
  "train_steps_per_second": 4.546
}
hpo-examples/question-answering/qa/trainer_state.json
ADDED
@@ -0,0 +1,245 @@
{
  "best_metric": null,
  "best_model_checkpoint": null,
  "epoch": 2.0,
  "eval_steps": 500,
  "global_step": 14754,
  "is_hyper_param_search": false,
  "is_local_process_zero": true,
  "is_world_process_zero": true,
  "log_history": [
    {"epoch": 0.06777822963264199, "grad_norm": 31.397275924682617, "learning_rate": 2.8983326555510372e-05, "loss": 2.7299, "step": 500},
    {"epoch": 0.13555645926528398, "grad_norm": 25.8492431640625, "learning_rate": 2.796665311102074e-05, "loss": 1.752, "step": 1000},
    {"epoch": 0.203334688897926, "grad_norm": 29.627431869506836, "learning_rate": 2.694997966653111e-05, "loss": 1.5588, "step": 1500},
    {"epoch": 0.27111291853056796, "grad_norm": 21.147193908691406, "learning_rate": 2.593330622204148e-05, "loss": 1.5014, "step": 2000},
    {"epoch": 0.33889114816321, "grad_norm": 17.81966781616211, "learning_rate": 2.491663277755185e-05, "loss": 1.4768, "step": 2500},
    {"epoch": 0.406669377795852, "grad_norm": 20.26822853088379, "learning_rate": 2.389995933306222e-05, "loss": 1.4064, "step": 3000},
    {"epoch": 0.47444760742849396, "grad_norm": 16.216028213500977, "learning_rate": 2.288328588857259e-05, "loss": 1.3502, "step": 3500},
    {"epoch": 0.5422258370611359, "grad_norm": 17.930505752563477, "learning_rate": 2.1866612444082963e-05, "loss": 1.3101, "step": 4000},
    {"epoch": 0.6100040666937779, "grad_norm": 26.499574661254883, "learning_rate": 2.084993899959333e-05, "loss": 1.2922, "step": 4500},
    {"epoch": 0.67778229632642, "grad_norm": 26.83368492126465, "learning_rate": 1.9833265555103702e-05, "loss": 1.3053, "step": 5000},
    {"epoch": 0.745560525959062, "grad_norm": 22.85872459411621, "learning_rate": 1.8816592110614073e-05, "loss": 1.2555, "step": 5500},
    {"epoch": 0.813338755591704, "grad_norm": 23.48080825805664, "learning_rate": 1.779991866612444e-05, "loss": 1.2068, "step": 6000},
    {"epoch": 0.8811169852243459, "grad_norm": 20.919252395629883, "learning_rate": 1.6783245221634812e-05, "loss": 1.1991, "step": 6500},
    {"epoch": 0.9488952148569879, "grad_norm": 23.9005126953125, "learning_rate": 1.576657177714518e-05, "loss": 1.2156, "step": 7000},
    {"epoch": 1.01667344448963, "grad_norm": 22.660743713378906, "learning_rate": 1.4749898332655551e-05, "loss": 1.0827, "step": 7500},
    {"epoch": 1.0844516741222718, "grad_norm": 25.28419303894043, "learning_rate": 1.373322488816592e-05, "loss": 0.8481, "step": 8000},
    {"epoch": 1.152229903754914, "grad_norm": 14.510698318481445, "learning_rate": 1.271655144367629e-05, "loss": 0.872, "step": 8500},
    {"epoch": 1.2200081333875559, "grad_norm": 29.12289810180664, "learning_rate": 1.1699877999186661e-05, "loss": 0.8375, "step": 9000},
    {"epoch": 1.287786363020198, "grad_norm": 19.038454055786133, "learning_rate": 1.0683204554697033e-05, "loss": 0.8464, "step": 9500},
    {"epoch": 1.35556459265284, "grad_norm": 21.09101676940918, "learning_rate": 9.666531110207402e-06, "loss": 0.8746, "step": 10000},
    {"epoch": 1.4233428222854818, "grad_norm": 20.79250144958496, "learning_rate": 8.649857665717772e-06, "loss": 0.8776, "step": 10500},
    {"epoch": 1.491121051918124, "grad_norm": 21.217571258544922, "learning_rate": 7.633184221228141e-06, "loss": 0.8523, "step": 11000},
    {"epoch": 1.5588992815507658, "grad_norm": 15.557079315185547, "learning_rate": 6.616510776738511e-06, "loss": 0.8387, "step": 11500},
    {"epoch": 1.626677511183408, "grad_norm": 14.53345012664795, "learning_rate": 5.5998373322488825e-06, "loss": 0.8377, "step": 12000},
    {"epoch": 1.6944557408160499, "grad_norm": 26.921611785888672, "learning_rate": 4.583163887759252e-06, "loss": 0.8449, "step": 12500},
    {"epoch": 1.7622339704486918, "grad_norm": 12.789366722106934, "learning_rate": 3.566490443269622e-06, "loss": 0.8547, "step": 13000},
    {"epoch": 1.830012200081334, "grad_norm": 37.19759750366211, "learning_rate": 2.549816998779992e-06, "loss": 0.818, "step": 13500},
    {"epoch": 1.8977904297139758, "grad_norm": 14.62682819366455, "learning_rate": 1.533143554290362e-06, "loss": 0.8128, "step": 14000},
    {"epoch": 1.965568659346618, "grad_norm": 21.051790237426758, "learning_rate": 5.164701098007319e-07, "loss": 0.8115, "step": 14500},
    {"epoch": 2.0, "step": 14754, "total_flos": 3.541929151120589e+16, "train_loss": 1.148573803161563, "train_runtime": 3245.3985, "train_samples_per_second": 54.554, "train_steps_per_second": 4.546}
  ],
  "logging_steps": 500,
  "max_steps": 14754,
  "num_input_tokens_seen": 0,
  "num_train_epochs": 2,
  "save_steps": 500,
  "stateful_callbacks": {
    "TrainerControl": {
      "args": {
        "should_epoch_stop": false,
        "should_evaluate": false,
        "should_log": false,
        "should_save": true,
        "should_training_stop": true
      },
      "attributes": {}
    }
  },
  "total_flos": 3.541929151120589e+16,
  "train_batch_size": 12,
  "trial_name": null,
  "trial_params": null
}
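Because `trainer_state.json` is plain JSON, the loss curve of this run can be inspected without TensorBoard. A small sketch that prints loss and learning rate per logged step; the path assumes this repository's layout:

```python
import json

# Path relative to the root of this repository.
with open("hpo-examples/question-answering/qa/trainer_state.json") as f:
    state = json.load(f)

# Per-step logging entries carry "loss"; the final summary entry carries "train_loss" instead.
for entry in state["log_history"]:
    if "loss" in entry:
        print(f"step {entry['step']:>5}  loss {entry['loss']:.4f}  lr {entry['learning_rate']:.2e}")
```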
hpo-examples/question-answering/qa/training_args.bin
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:fe8e61ba1ca1cb106ca9adca5e9262fa9a262238814728a69256855c78c32f51
size 5304
hpo-examples/question-answering/qa/vocab.txt
ADDED
The diff for this file is too large to render.
See raw diff
hpo-examples/question-answering/requirements.txt
ADDED
@@ -0,0 +1,4 @@
accelerate >= 0.12.0
datasets >= 1.8.0
torch >= 1.3.0
evaluate