yuntian-deng commited on
Commit
6bb4da6
·
1 Parent(s): 29d8c86

Create config_identity.yaml

Browse files
Files changed (1) hide show
  1. config_identity.yaml +83 -0
config_identity.yaml ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ base_learning_rate: 2.0e-05
3
+ target: latent_diffusion.ldm.models.diffusion.ddpm.LatentDiffusion
4
+ params:
5
+ linear_start: 0.0015
6
+ linear_end: 0.0195
7
+ num_timesteps_cond: 1
8
+ log_every_t: 200
9
+ timesteps: 1000
10
+ first_stage_key: image
11
+ cond_stage_key: action_
12
+ scheduler_sampling_rate: 0.0
13
+ hybrid_key: c_concat
14
+ image_size: 256
15
+ channels: 3
16
+ cond_stage_trainable: true
17
+ conditioning_key: hybrid
18
+ monitor: val/loss_simple_ema
19
+
20
+ unet_config:
21
+ target: latent_diffusion.ldm.modules.diffusionmodules.openaimodel.UNetModel
22
+ params:
23
+ image_size: 256
24
+ in_channels: 24
25
+ out_channels: 3
26
+ model_channels: 192
27
+ attention_resolutions:
28
+ - 32
29
+ - 16
30
+ - 8
31
+ num_res_blocks: 2
32
+ channel_mult:
33
+ - 1
34
+ - 2
35
+ - 3
36
+ - 5
37
+ num_head_channels: 32
38
+ use_spatial_transformer: true
39
+ transformer_depth: 1
40
+ context_dim: 768
41
+
42
+ first_stage_config:
43
+ target: latent_diffusion.ldm.models.autoencoder.IdentityFirstStage
44
+ params:
45
+ embed_dim: 3
46
+ n_embed: 8192
47
+
48
+ cond_stage_config:
49
+ target: latent_diffusion.ldm.modules.encoders.modules.GPTEmbedder
50
+ params:
51
+ n_embed: 768
52
+ n_layer: 12
53
+
54
+ data:
55
+ target: data.data_processing.datasets.DataModule
56
+ params:
57
+ batch_size: 16
58
+ num_workers: 1
59
+ wrap: false
60
+ shuffle: True
61
+ drop_last: True
62
+ pin_memory: True
63
+ prefetch_factor: 2
64
+ persistent_workers: True
65
+ train:
66
+ target: data.data_processing.datasets.ActionsData
67
+ params:
68
+ data_csv_path: train_dataset/train_dataset_14frames_firstframe.csv
69
+ validation:
70
+ target: data.data_processing.datasets.ActionsData
71
+ params:
72
+ data_csv_path: train_dataset/train_dataset_14frames_firstframe.csv
73
+
74
+ lightning:
75
+ trainer:
76
+ benchmark: False
77
+ max_epochs: 4
78
+ limit_val_batches: 0
79
+ accelerator: gpu
80
+ gpus: 1
81
+ accumulate_grad_batches: 8
82
+ gradient_clip_val: 1
83
+ checkpoint_callback: True