da03 commited on
Commit
8f6c968
·
1 Parent(s): f474968
standard_challenging_context32_nocond_all.yaml ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ save_path: saved_standard_challenging_context32_nocond_cont_cont_all_cont
2
+
3
+ model:
4
+ base_learning_rate: 8.0e-05
5
+ target: ldm.models.diffusion.ddpm.LatentDiffusion
6
+ params:
7
+ linear_start: 0.0015
8
+ linear_end: 0.0195
9
+ num_timesteps_cond: 1
10
+ log_every_t: 200
11
+ timesteps: 1000
12
+ first_stage_key: image
13
+ cond_stage_key: action_
14
+ scheduler_sampling_rate: 0.0
15
+ hybrid_key: c_concat
16
+ image_size: [64, 48]
17
+ channels: 3
18
+ cond_stage_trainable: false
19
+ conditioning_key: hybrid
20
+ monitor: val/loss_simple_ema
21
+
22
+ unet_config:
23
+ target: ldm.modules.diffusionmodules.openaimodel.UNetModel
24
+ params:
25
+ image_size: [64, 48]
26
+ in_channels: 166
27
+ out_channels: 4
28
+ model_channels: 192
29
+ attention_resolutions:
30
+ - 8
31
+ - 4
32
+ - 2
33
+ num_res_blocks: 2
34
+ channel_mult:
35
+ - 1
36
+ - 2
37
+ - 3
38
+ - 5
39
+ num_head_channels: 32
40
+ use_spatial_transformer: false
41
+ transformer_depth: 1
42
+
43
+ first_stage_config:
44
+ target: ldm.models.autoencoder.AutoencoderKL
45
+ params:
46
+ embed_dim: 4
47
+ monitor: val/rec_loss
48
+ ddconfig:
49
+ double_z: true
50
+ z_channels: 4
51
+ resolution: 256
52
+ in_channels: 3
53
+ out_ch: 3
54
+ ch: 128
55
+ ch_mult:
56
+ - 1
57
+ - 2
58
+ - 4
59
+ - 4
60
+ num_res_blocks: 2
61
+ attn_resolutions: []
62
+ dropout: 0.0
63
+ lossconfig:
64
+ target: torch.nn.Identity
65
+
66
+ cond_stage_config: __is_unconditional__
67
+
68
+ data:
69
+ target: data.data_processing.datasets.DataModule
70
+ params:
71
+ batch_size: 256
72
+ num_workers: 16
73
+ wrap: false
74
+ shuffle: True
75
+ drop_last: True
76
+ pin_memory: True
77
+ prefetch_factor: 2
78
+ persistent_workers: True
79
+ train:
80
+ target: data.data_processing.datasets.ActionsData
81
+ params:
82
+ data_csv_path: train_dataset/train_dataset.target_frames.csv
83
+ normalization: standard
84
+ context_length: 32
85
+ #validation:
86
+ # target: data.data_processing.datasets.ActionsData
87
+ # params:
88
+ # data_csv_path: train_dataset/train_dataset_14frames_firstframe_allframes.csv
89
+
90
+ lightning:
91
+ trainer:
92
+ benchmark: False
93
+ max_epochs: 6400
94
+ limit_val_batches: 0
95
+ accelerator: gpu
96
+ gpus: 1
97
+ accumulate_grad_batches: 1
98
+ gradient_clip_val: 1
99
+ checkpoint_callback: True