da03 commited on
Commit
04846cf
·
1 Parent(s): 2d60859
Files changed (1) hide show
  1. main.py +21 -6
main.py CHANGED
@@ -25,7 +25,7 @@ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
25
  DEBUG_MODE = False
26
  DEBUG_MODE_2 = False
27
  NUM_MAX_FRAMES = 1
28
-
29
  SCREEN_WIDTH = 512
30
  SCREEN_HEIGHT = 384
31
  NUM_SAMPLING_STEPS = 32
@@ -48,8 +48,15 @@ MODEL_NAME = "yuntian-deng/computer-model-ss005-cont-lr2e5-computecanada-newnewd
48
  MODEL_NAME = "yuntian-deng/computer-model-ss005-cont-lr2e5-computecanada-newnewd-unfreezernn-198k"
49
  MODEL_NAME = "yuntian-deng/computer-model-ss005-cont-lr2e5-computecanada-newnewd-freezernn-origunet-nospatial-674k"
50
  MODEL_NAME = "yuntian-deng/computer-model-ss005-cont-lr2e5-computecanada-newnewd-freezernn-origunet-nospatial-online-74k"
51
- MODEL_NAME = "yuntian-deng/computer-model-ss005-cont-lr2e5-computecanada-newnewd-freezernn-origunet-nospatial-online-x0-22k"
52
  MODEL_NAME = "yuntian-deng/computer-model-s-newnewd-freezernn-origunet-nospatial-online-online-70k"
 
 
 
 
 
 
 
 
53
 
54
 
55
  print (f'setting: DEBUG_MODE: {DEBUG_MODE}, DEBUG_MODE_2: {DEBUG_MODE_2}, NUM_MAX_FRAMES: {NUM_MAX_FRAMES}, NUM_SAMPLING_STEPS: {NUM_SAMPLING_STEPS}, MODEL_NAME: {MODEL_NAME}')
@@ -67,9 +74,17 @@ LATENT_DIMS = (16, SCREEN_HEIGHT // 8, SCREEN_WIDTH // 8)
67
 
68
  if 'origunet' in MODEL_NAME:
69
  if 'x0' in MODEL_NAME:
70
- model = initialize_model("config_final_model_origunet_nospatial_x0.yaml", MODEL_NAME)
 
 
 
 
71
  else:
72
- model = initialize_model("config_final_model_origunet_nospatial.yaml", MODEL_NAME)
 
 
 
 
73
  else:
74
  model = initialize_model("config_final_model.yaml", MODEL_NAME)
75
 
@@ -205,12 +220,12 @@ def _process_frame_sync(model, inputs, use_rnn, num_sampling_steps):
205
  sample_latent = output_from_rnn[:, :16]
206
  else:
207
  #NUM_SAMPLING_STEPS = 8
208
- if num_sampling_steps >= 1000:
209
  sample_latent = model.p_sample_loop(cond={'c_concat': output_from_rnn}, shape=[1, *LATENT_DIMS], return_intermediates=False, verbose=True)
210
  else:
211
  if num_sampling_steps == 1:
212
  x = torch.randn([1, *LATENT_DIMS], device=device)
213
- t = torch.full((1,), 999, device=device, dtype=torch.long)
214
  sample_latent = model.apply_model(x, t, {'c_concat': output_from_rnn})
215
  else:
216
  sampler = DDIMSampler(model)
 
25
  DEBUG_MODE = False
26
  DEBUG_MODE_2 = False
27
  NUM_MAX_FRAMES = 1
28
+ TIMESTEPS = 1000
29
  SCREEN_WIDTH = 512
30
  SCREEN_HEIGHT = 384
31
  NUM_SAMPLING_STEPS = 32
 
48
  MODEL_NAME = "yuntian-deng/computer-model-ss005-cont-lr2e5-computecanada-newnewd-unfreezernn-198k"
49
  MODEL_NAME = "yuntian-deng/computer-model-ss005-cont-lr2e5-computecanada-newnewd-freezernn-origunet-nospatial-674k"
50
  MODEL_NAME = "yuntian-deng/computer-model-ss005-cont-lr2e5-computecanada-newnewd-freezernn-origunet-nospatial-online-74k"
 
51
  MODEL_NAME = "yuntian-deng/computer-model-s-newnewd-freezernn-origunet-nospatial-online-online-70k"
52
+ MODEL_NAME = "yuntian-deng/computer-model-ss005-cont-lr2e5-computecanada-newnewd-freezernn-origunet-nospatial-online-x0-46k"
53
+ MODEL_NAME = "yuntian-deng/computer-model-s-newnewd-freezernn-origunet-nospatial-online-x0-142k"
54
+ MODEL_NAME = "yuntian-deng/computer-model-s-newnewd-freezernn-origunet-nospatial-online-x0-338k"
55
+ MODEL_NAME = "yuntian-deng/computer-model-s-newnewd-freezernn-origunet-nospatial-online-ddpm32-x0-140k"
56
+ MODEL_NAME = "yuntian-deng/computer-model-s-newnewd-freezernn-origunet-nospatial-online-ddpm32-eps-144k"
57
+ MODEL_NAME = "yuntian-deng/computer-model-s-newnewd-freezernn-origunet-nospatial-online-x0-joint-onlineonly-70k"
58
+ MODEL_NAME = "yuntian-deng/computer-model-s-newnewd-freezernn-origunet-nospatial-online-joint-onlineonly-eps22-40k"
59
+ MODEL_NAME = "yuntian-deng/computer-model-s-newnewd-freezernn-origunet-nospatial-online-x0-joint-onlineonly-22-38k"
60
 
61
 
62
  print (f'setting: DEBUG_MODE: {DEBUG_MODE}, DEBUG_MODE_2: {DEBUG_MODE_2}, NUM_MAX_FRAMES: {NUM_MAX_FRAMES}, NUM_SAMPLING_STEPS: {NUM_SAMPLING_STEPS}, MODEL_NAME: {MODEL_NAME}')
 
74
 
75
  if 'origunet' in MODEL_NAME:
76
  if 'x0' in MODEL_NAME:
77
+ if 'ddpm32' in MODEL_NAME:
78
+ TIMESTEPS = 32
79
+ model = initialize_model("config_final_model_origunet_nospatial_x0_ddpm32.yaml", MODEL_NAME)
80
+ else:
81
+ model = initialize_model("config_final_model_origunet_nospatial_x0.yaml", MODEL_NAME)
82
  else:
83
+ if 'ddpm32' in MODEL_NAME:
84
+ TIMESTEPS = 32
85
+ model = initialize_model("config_final_model_origunet_nospatial_ddpm32.yaml", MODEL_NAME)
86
+ else:
87
+ model = initialize_model("config_final_model_origunet_nospatial.yaml", MODEL_NAME)
88
  else:
89
  model = initialize_model("config_final_model.yaml", MODEL_NAME)
90
 
 
220
  sample_latent = output_from_rnn[:, :16]
221
  else:
222
  #NUM_SAMPLING_STEPS = 8
223
+ if num_sampling_steps >= TIMESTEPS:
224
  sample_latent = model.p_sample_loop(cond={'c_concat': output_from_rnn}, shape=[1, *LATENT_DIMS], return_intermediates=False, verbose=True)
225
  else:
226
  if num_sampling_steps == 1:
227
  x = torch.randn([1, *LATENT_DIMS], device=device)
228
+ t = torch.full((1,), TIMESTEPS-1, device=device, dtype=torch.long)
229
  sample_latent = model.apply_model(x, t, {'c_concat': output_from_rnn})
230
  else:
231
  sampler = DDIMSampler(model)