Spaces:
Runtime error
Runtime error
update for sagemaker
Browse files- README.md +15 -12
- src/train_unconditional.py +5 -5
README.md
CHANGED
|
@@ -1,19 +1,22 @@
|
|
| 1 |
# audio-diffusion
|
| 2 |
```bash
|
|
|
|
|
|
|
|
|
|
| 3 |
python src/audio_to_images.py \
|
| 4 |
-
--resolution
|
| 5 |
-
--input_dir
|
| 6 |
-
--output_dir
|
| 7 |
```
|
| 8 |
```bash
|
| 9 |
accelerate launch src/train_unconditional.py \
|
| 10 |
-
--dataset_name
|
| 11 |
-
--resolution
|
| 12 |
-
--output_dir
|
| 13 |
-
--train_batch_size
|
| 14 |
-
--num_epochs
|
| 15 |
-
--gradient_accumulation_steps
|
| 16 |
-
--learning_rate
|
| 17 |
-
--lr_warmup_steps
|
| 18 |
-
--mixed_precision
|
| 19 |
```
|
|
|
|
| 1 |
# audio-diffusion
|
| 2 |
```bash
|
| 3 |
+
accelerate config
|
| 4 |
+
```
|
| 5 |
+
```bash
|
| 6 |
python src/audio_to_images.py \
|
| 7 |
+
--resolution 256 \
|
| 8 |
+
--input_dir path-to-audio-files \
|
| 9 |
+
--output_dir data-256
|
| 10 |
```
|
| 11 |
```bash
|
| 12 |
accelerate launch src/train_unconditional.py \
|
| 13 |
+
--dataset_name data-256 \
|
| 14 |
+
--resolution 256 \
|
| 15 |
+
--output_dir ddpm-ema-audio-256 \
|
| 16 |
+
--train_batch_size 16 \
|
| 17 |
+
--num_epochs 100 \
|
| 18 |
+
--gradient_accumulation_steps 1 \
|
| 19 |
+
--learning_rate 1e-4 \
|
| 20 |
+
--lr_warmup_steps 500 \
|
| 21 |
+
--mixed_precision no
|
| 22 |
```
|
src/train_unconditional.py
CHANGED
|
@@ -253,7 +253,7 @@ if __name__ == "__main__":
|
|
| 253 |
help="A folder containing the training data.",
|
| 254 |
)
|
| 255 |
parser.add_argument("--output_dir", type=str, default="ddpm-model-64")
|
| 256 |
-
parser.add_argument("--overwrite_output_dir",
|
| 257 |
parser.add_argument("--cache_dir", type=str, default=None)
|
| 258 |
parser.add_argument("--resolution", type=int, default=64)
|
| 259 |
parser.add_argument("--train_batch_size", type=int, default=16)
|
|
@@ -269,15 +269,15 @@ if __name__ == "__main__":
|
|
| 269 |
parser.add_argument("--adam_beta2", type=float, default=0.999)
|
| 270 |
parser.add_argument("--adam_weight_decay", type=float, default=1e-6)
|
| 271 |
parser.add_argument("--adam_epsilon", type=float, default=1e-08)
|
| 272 |
-
parser.add_argument("--use_ema",
|
| 273 |
parser.add_argument("--ema_inv_gamma", type=float, default=1.0)
|
| 274 |
parser.add_argument("--ema_power", type=float, default=3 / 4)
|
| 275 |
parser.add_argument("--ema_max_decay", type=float, default=0.9999)
|
| 276 |
-
parser.add_argument("--push_to_hub",
|
| 277 |
-
parser.add_argument("--use_auth_token",
|
| 278 |
parser.add_argument("--hub_token", type=str, default=None)
|
| 279 |
parser.add_argument("--hub_model_id", type=str, default=None)
|
| 280 |
-
parser.add_argument("--hub_private_repo",
|
| 281 |
parser.add_argument("--logging_dir", type=str, default="logs")
|
| 282 |
parser.add_argument(
|
| 283 |
"--mixed_precision",
|
|
|
|
| 253 |
help="A folder containing the training data.",
|
| 254 |
)
|
| 255 |
parser.add_argument("--output_dir", type=str, default="ddpm-model-64")
|
| 256 |
+
parser.add_argument("--overwrite_output_dir", type=bool, default=False)
|
| 257 |
parser.add_argument("--cache_dir", type=str, default=None)
|
| 258 |
parser.add_argument("--resolution", type=int, default=64)
|
| 259 |
parser.add_argument("--train_batch_size", type=int, default=16)
|
|
|
|
| 269 |
parser.add_argument("--adam_beta2", type=float, default=0.999)
|
| 270 |
parser.add_argument("--adam_weight_decay", type=float, default=1e-6)
|
| 271 |
parser.add_argument("--adam_epsilon", type=float, default=1e-08)
|
| 272 |
+
parser.add_argument("--use_ema", type=bool, default=True)
|
| 273 |
parser.add_argument("--ema_inv_gamma", type=float, default=1.0)
|
| 274 |
parser.add_argument("--ema_power", type=float, default=3 / 4)
|
| 275 |
parser.add_argument("--ema_max_decay", type=float, default=0.9999)
|
| 276 |
+
parser.add_argument("--push_to_hub", type=bool, default=False)
|
| 277 |
+
parser.add_argument("--use_auth_token", type=bool, default=False)
|
| 278 |
parser.add_argument("--hub_token", type=str, default=None)
|
| 279 |
parser.add_argument("--hub_model_id", type=str, default=None)
|
| 280 |
+
parser.add_argument("--hub_private_repo", type=bool, default=False)
|
| 281 |
parser.add_argument("--logging_dir", type=str, default="logs")
|
| 282 |
parser.add_argument(
|
| 283 |
"--mixed_precision",
|