da03 committed
Commit f0fd514 · 1 Parent(s): 2f5fac4
config_final_model_origunet_nospatial_x0.yaml ADDED
@@ -0,0 +1,110 @@
+save_path: saved_standard_challenging_context32_nocond_cont_cont_all_cont_eval
+
+model:
+  base_learning_rate: 8.0e-05
+  target: ldm.models.diffusion.ddpm.LatentDiffusion
+  params:
+    parameterization: x0
+    linear_start: 0.0015
+    linear_end: 0.0195
+    num_timesteps_cond: 1
+    log_every_t: 200
+    timesteps: 1000
+    first_stage_key: image
+    cond_stage_key: action_
+    scheduler_sampling_rate: 0.0
+    hybrid_key: c_concat
+    image_size: [64, 48]
+    channels: 3
+    cond_stage_trainable: false
+    conditioning_key: hybrid
+    monitor: val/loss_simple_ema
+
+    unet_config:
+      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
+      params:
+        image_size: [64, 48]
+        in_channels: 48
+        out_channels: 16
+        model_channels: 192
+        attention_resolutions:
+        - 2
+        - 4
+        - 8
+        num_res_blocks: 2
+        channel_mult:
+        - 1
+        - 2
+        - 3
+        - 5
+        num_head_channels: 32
+        use_spatial_transformer: false
+        transformer_depth: 1
+
+    temporal_encoder_config:
+      target: ldm.modules.encoders.temporal_encoder.TemporalEncoder
+      params:
+        input_channels: 16
+        hidden_size: 4096
+        num_layers: 1
+        dropout: 0.1
+        output_channels: 32
+        output_height: 48
+        output_width: 64
+
+    first_stage_config:
+      target: ldm.models.autoencoder.AutoencoderKL
+      params:
+        embed_dim: 16
+        monitor: val/rec_loss
+        ddconfig:
+          double_z: true
+          z_channels: 16
+          resolution: 256
+          in_channels: 3
+          out_ch: 3
+          ch: 128
+          ch_mult:
+          - 1
+          - 2
+          - 4
+          - 4
+          num_res_blocks: 2
+          attn_resolutions: []
+          dropout: 0.0
+        lossconfig:
+          target: torch.nn.Identity
+
+    cond_stage_config: __is_unconditional__
+
+data:
+  target: data.data_processing.datasets.DataModule
+  params:
+    batch_size: 8
+    num_workers: 1
+    wrap: false
+    shuffle: True
+    drop_last: True
+    pin_memory: True
+    prefetch_factor: 2
+    persistent_workers: True
+    train:
+      target: data.data_processing.datasets.ActionsData
+      params:
+        data_csv_path: desktop_sequences_filtered_with_desktop_1.5k.challenging.train.target_frames.csv
+        normalization: standard
+        context_length: 32
+    #validation:
+    #  target: data.data_processing.datasets.ActionsData
+    #  params:
+
+lightning:
+  trainer:
+    benchmark: False
+    max_epochs: 6400
+    limit_val_batches: 0
+    accelerator: gpu
+    gpus: 1
+    accumulate_grad_batches: 999999
+    gradient_clip_val: 1
+    checkpoint_callback: True
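
For reference, the config above follows the usual latent-diffusion `target`/`params` convention: each block names a class by dotted path and the keyword arguments to construct it with. The sketch below shows how such a config is typically instantiated; it is an illustration only, and the actual `initialize_model` in `main.py` may differ.

```python
# Minimal sketch, assuming the standard latent-diffusion target/params convention.
# The real loader in this repo (initialize_model in main.py) may differ.
import importlib

from omegaconf import OmegaConf


def get_obj_from_str(path: str):
    """Resolve 'pkg.module.Class' to the class object."""
    module, cls = path.rsplit(".", 1)
    return getattr(importlib.import_module(module), cls)


def instantiate_from_config(config):
    """Build target(**params) for one config block."""
    return get_obj_from_str(config["target"])(**config.get("params", dict()))


if __name__ == "__main__":
    cfg = OmegaConf.load("config_final_model_origunet_nospatial_x0.yaml")
    # cfg.model.params.parameterization == "x0" is what distinguishes this
    # config from config_final_model_origunet_nospatial.yaml.
    model = instantiate_from_config(cfg.model)
```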
main.py CHANGED
@@ -48,6 +48,7 @@ MODEL_NAME = "yuntian-deng/computer-model-ss005-cont-lr2e5-computecanada-newnewd
 MODEL_NAME = "yuntian-deng/computer-model-ss005-cont-lr2e5-computecanada-newnewd-unfreezernn-198k"
 MODEL_NAME = "yuntian-deng/computer-model-ss005-cont-lr2e5-computecanada-newnewd-freezernn-origunet-nospatial-674k"
 MODEL_NAME = "yuntian-deng/computer-model-ss005-cont-lr2e5-computecanada-newnewd-freezernn-origunet-nospatial-online-74k"
+MODEL_NAME = "yuntian-deng/computer-model-ss005-cont-lr2e5-computecanada-newnewd-freezernn-origunet-nospatial-online-x0"
 
 
 print (f'setting: DEBUG_MODE: {DEBUG_MODE}, DEBUG_MODE_2: {DEBUG_MODE_2}, NUM_MAX_FRAMES: {NUM_MAX_FRAMES}, NUM_SAMPLING_STEPS: {NUM_SAMPLING_STEPS}, MODEL_NAME: {MODEL_NAME}')
@@ -64,10 +65,14 @@ LATENT_DIMS = (16, SCREEN_HEIGHT // 8, SCREEN_WIDTH // 8)
 #model = initialize_model("config_final_model.yaml", "yuntian-deng/computer-model")
 
 if 'origunet' in MODEL_NAME:
-    model = initialize_model("config_final_model_origunet_nospatial.yaml", MODEL_NAME)
+    if 'x0' in MODEL_NAME:
+        model = initialize_model("config_final_model_origunet_nospatial_x0.yaml", MODEL_NAME)
+    else:
+        model = initialize_model("config_final_model_origunet_nospatial.yaml", MODEL_NAME)
 else:
     model = initialize_model("config_final_model.yaml", MODEL_NAME)
 
+
 model = model.to(device)
 #model = torch.compile(model)
 padding_image = torch.zeros(*LATENT_DIMS).unsqueeze(0).to(device)
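
The model-side change this commit switches to is `parameterization: x0` in the new config: the UNet regresses the clean latent instead of the added noise. The sketch below illustrates that distinction in a generic DDPM-style setup; the variable names follow common DDPM notation and are assumptions, not copied from `ldm.models.diffusion.ddpm`.

```python
# Sketch of the eps- vs x0-parameterization difference in a DDPM-style loss.
# Names (alphas_cumprod, make_training_pair) are illustrative assumptions.
import torch


def make_training_pair(x_start, t, alphas_cumprod, parameterization="x0"):
    """Return (noisy latent, regression target) for a single timestep t."""
    noise = torch.randn_like(x_start)
    a_bar = alphas_cumprod[t]  # cumulative product of alphas at step t
    # Forward process: x_t = sqrt(a_bar_t) * x_0 + sqrt(1 - a_bar_t) * eps
    x_noisy = a_bar.sqrt() * x_start + (1.0 - a_bar).sqrt() * noise
    # eps-parameterization: the network predicts the added noise.
    # x0-parameterization (this config): the network predicts the clean latent.
    target = x_start if parameterization == "x0" else noise
    return x_noisy, target
```

At sampling time the predicted x_0 is plugged into the posterior mean directly rather than being recovered from a predicted eps, so the change is confined to what the network outputs and what the loss compares against.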