da03 committed on
Commit
9df0e98
·
1 Parent(s): 3a81ef8
config_final_model_origunet_nospatial.yaml ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ save_path: saved_standard_challenging_context32_nocond_cont_cont_all_cont_eval
2
+
3
+ model:
4
+ base_learning_rate: 8.0e-05
5
+ target: ldm.models.diffusion.ddpm.LatentDiffusion
6
+ params:
7
+ linear_start: 0.0015
8
+ linear_end: 0.0195
9
+ num_timesteps_cond: 1
10
+ log_every_t: 200
11
+ timesteps: 1000
12
+ first_stage_key: image
13
+ cond_stage_key: action_
14
+ scheduler_sampling_rate: 0.0
15
+ hybrid_key: c_concat
16
+ image_size: [64, 48]
17
+ channels: 3
18
+ cond_stage_trainable: false
19
+ conditioning_key: hybrid
20
+ monitor: val/loss_simple_ema
21
+
22
+ unet_config:
23
+ target: ldm.modules.diffusionmodules.openaimodel.UNetModel
24
+ params:
25
+ image_size: [64, 48]
26
+ in_channels: 48
27
+ out_channels: 16
28
+ model_channels: 192
29
+ attention_resolutions:
30
+ - 2
31
+ - 4
32
+ - 8
33
+ num_res_blocks: 2
34
+ channel_mult:
35
+ - 1
36
+ - 2
37
+ - 3
38
+ - 5
39
+ num_head_channels: 32
40
+ use_spatial_transformer: false
41
+ transformer_depth: 1
42
+
43
+ temporal_encoder_config:
44
+ target: ldm.modules.encoders.temporal_encoder.TemporalEncoder
45
+ params:
46
+ input_channels: 16
47
+ hidden_size: 4096
48
+ num_layers: 1
49
+ dropout: 0.1
50
+ output_channels: 32
51
+ output_height: 48
52
+ output_width: 64
53
+
54
+ first_stage_config:
55
+ target: ldm.models.autoencoder.AutoencoderKL
56
+ params:
57
+ embed_dim: 16
58
+ monitor: val/rec_loss
59
+ ddconfig:
60
+ double_z: true
61
+ z_channels: 16
62
+ resolution: 256
63
+ in_channels: 3
64
+ out_ch: 3
65
+ ch: 128
66
+ ch_mult:
67
+ - 1
68
+ - 2
69
+ - 4
70
+ - 4
71
+ num_res_blocks: 2
72
+ attn_resolutions: []
73
+ dropout: 0.0
74
+ lossconfig:
75
+ target: torch.nn.Identity
76
+
77
+ cond_stage_config: __is_unconditional__
78
+
79
+ data:
80
+ target: data.data_processing.datasets.DataModule
81
+ params:
82
+ batch_size: 8
83
+ num_workers: 1
84
+ wrap: false
85
+ shuffle: True
86
+ drop_last: True
87
+ pin_memory: True
88
+ prefetch_factor: 2
89
+ persistent_workers: True
90
+ train:
91
+ target: data.data_processing.datasets.ActionsData
92
+ params:
93
+ data_csv_path: desktop_sequences_filtered_with_desktop_1.5k.challenging.train.target_frames.csv
94
+ normalization: standard
95
+ context_length: 32
96
+ #validation:
97
+ # target: data.data_processing.datasets.ActionsData
98
+ # params:
99
+
100
+ lightning:
101
+ trainer:
102
+ benchmark: False
103
+ max_epochs: 6400
104
+ limit_val_batches: 0
105
+ accelerator: gpu
106
+ gpus: 1
107
+ accumulate_grad_batches: 999999
108
+ gradient_clip_val: 1
109
+ checkpoint_callback: True
main.py CHANGED
@@ -42,6 +42,10 @@ MODEL_NAME = "yuntian-deng/computer-model-ss005-cont-lr2e5-computecanada-54k"
42
 
43
 
44
 
 
 
 
 
45
  print (f'setting: DEBUG_MODE: {DEBUG_MODE}, DEBUG_MODE_2: {DEBUG_MODE_2}, NUM_MAX_FRAMES: {NUM_MAX_FRAMES}, NUM_SAMPLING_STEPS: {NUM_SAMPLING_STEPS}, MODEL_NAME: {MODEL_NAME}')
46
 
47
  with open('latent_stats.json', 'r') as f:
@@ -54,7 +58,8 @@ LATENT_DIMS = (16, SCREEN_HEIGHT // 8, SCREEN_WIDTH // 8)
54
  #model = initialize_model("config_rnn.yaml", "yuntian-deng/computer-model")
55
  #model = initialize_model("config_final_model.yaml", "yuntian-deng/computer-model-noss")
56
  #model = initialize_model("config_final_model.yaml", "yuntian-deng/computer-model")
57
- model = initialize_model("config_final_model.yaml", MODEL_NAME)
 
58
 
59
  model = model.to(device)
60
  #model = torch.compile(model)
 
42
 
43
 
44
 
45
+ MODEL_NAME = "yuntian-deng/computer-model-ss005-cont-lr2e5-computecanada-newnewd-unfreezernn-160k"
46
+ MODEL_NAME = "yuntian-deng/computer-model-ss005-cont-lr2e5-computecanada-newnewd-freezernn-origunet-nospatial-368k"
47
+
48
+
49
  print (f'setting: DEBUG_MODE: {DEBUG_MODE}, DEBUG_MODE_2: {DEBUG_MODE_2}, NUM_MAX_FRAMES: {NUM_MAX_FRAMES}, NUM_SAMPLING_STEPS: {NUM_SAMPLING_STEPS}, MODEL_NAME: {MODEL_NAME}')
50
 
51
  with open('latent_stats.json', 'r') as f:
 
58
  #model = initialize_model("config_rnn.yaml", "yuntian-deng/computer-model")
59
  #model = initialize_model("config_final_model.yaml", "yuntian-deng/computer-model-noss")
60
  #model = initialize_model("config_final_model.yaml", "yuntian-deng/computer-model")
61
+ #model = initialize_model("config_final_model.yaml", MODEL_NAME)
62
+ model = initialize_model("config_final_model_origunet_nospatial.yaml", MODEL_NAME)
63
 
64
  model = model.to(device)
65
  #model = torch.compile(model)