yuntian-deng commited on
Commit
e531498
·
1 Parent(s): a076741

Upload 2e5_debug_gpt_firstframe_posmap_longtrainh200.yaml

Browse files
2e5_debug_gpt_firstframe_posmap_longtrainh200.yaml ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ save_path: saved_fixcursor_lr2e5_debug_gpt_firstframe_posmap_longtrainh200
2
+
3
+ model:
4
+ base_learning_rate: 2.0e-05
5
+ target: latent_diffusion.ldm.models.diffusion.ddpm.LatentDiffusion
6
+ params:
7
+ linear_start: 0.0015
8
+ linear_end: 0.0195
9
+ num_timesteps_cond: 1
10
+ log_every_t: 200
11
+ timesteps: 1000
12
+ first_stage_key: image
13
+ cond_stage_key: action_
14
+ scheduler_sampling_rate: 0.0
15
+ hybrid_key: c_concat
16
+ image_size: 64
17
+ channels: 3
18
+ cond_stage_trainable: true
19
+ conditioning_key: hybrid
20
+ monitor: val/loss_simple_ema
21
+
22
+ unet_config:
23
+ target: latent_diffusion.ldm.modules.diffusionmodules.openaimodel.UNetModel
24
+ params:
25
+ image_size: 64
26
+ in_channels: 25
27
+ out_channels: 3
28
+ model_channels: 192
29
+ attention_resolutions:
30
+ - 8
31
+ - 4
32
+ - 2
33
+ num_res_blocks: 2
34
+ channel_mult:
35
+ - 1
36
+ - 2
37
+ - 3
38
+ - 5
39
+ num_head_channels: 32
40
+ use_spatial_transformer: true
41
+ transformer_depth: 1
42
+ context_dim: 768
43
+ first_stage_config:
44
+ target: latent_diffusion.ldm.models.autoencoder.VQModelInterface
45
+ params:
46
+ embed_dim: 3
47
+ n_embed: 8192
48
+ monitor: val/rec_loss
49
+
50
+ ddconfig:
51
+ double_z: false
52
+ z_channels: 3
53
+ resolution: 256
54
+ in_channels: 3
55
+ out_ch: 3
56
+ ch: 128
57
+ ch_mult:
58
+ - 1
59
+ - 2
60
+ - 4
61
+ num_res_blocks: 2
62
+ attn_resolutions: []
63
+ dropout: 0.0
64
+ lossconfig:
65
+ target: torch.nn.Identity
66
+
67
+ cond_stage_config:
68
+ target: latent_diffusion.ldm.modules.encoders.modules.GPTEmbedder
69
+ params:
70
+ n_embed: 768
71
+ n_layer: 12
72
+
73
+ data:
74
+ target: data.data_processing.datasets.DataModule
75
+ params:
76
+ batch_size: 64
77
+ num_workers: 1
78
+ wrap: false
79
+ shuffle: True
80
+ drop_last: True
81
+ pin_memory: True
82
+ prefetch_factor: 2
83
+ persistent_workers: True
84
+ train:
85
+ target: data.data_processing.datasets.ActionsData
86
+ params:
87
+ data_csv_path: train_dataset/train_dataset_14frames_firstframe.csv
88
+ validation:
89
+ target: data.data_processing.datasets.ActionsData
90
+ params:
91
+ data_csv_path: train_dataset/train_dataset_14frames_firstframe.csv
92
+
93
+ lightning:
94
+ trainer:
95
+ benchmark: False
96
+ max_epochs: 16
97
+ limit_val_batches: 0
98
+ accelerator: gpu
99
+ gpus: 1
100
+ accumulate_grad_batches: 2
101
+ gradient_clip_val: 1
102
+ checkpoint_callback: True