da03 commited on
Commit
2d60859
·
1 Parent(s): f0fd514
Files changed (1) hide show
  1. main.py +15 -9
main.py CHANGED
@@ -48,7 +48,8 @@ MODEL_NAME = "yuntian-deng/computer-model-ss005-cont-lr2e5-computecanada-newnewd
48
  MODEL_NAME = "yuntian-deng/computer-model-ss005-cont-lr2e5-computecanada-newnewd-unfreezernn-198k"
49
  MODEL_NAME = "yuntian-deng/computer-model-ss005-cont-lr2e5-computecanada-newnewd-freezernn-origunet-nospatial-674k"
50
  MODEL_NAME = "yuntian-deng/computer-model-ss005-cont-lr2e5-computecanada-newnewd-freezernn-origunet-nospatial-online-74k"
51
- MODEL_NAME = "yuntian-deng/computer-model-ss005-cont-lr2e5-computecanada-newnewd-freezernn-origunet-nospatial-online-x0"
 
52
 
53
 
54
  print (f'setting: DEBUG_MODE: {DEBUG_MODE}, DEBUG_MODE_2: {DEBUG_MODE_2}, NUM_MAX_FRAMES: {NUM_MAX_FRAMES}, NUM_SAMPLING_STEPS: {NUM_SAMPLING_STEPS}, MODEL_NAME: {MODEL_NAME}')
@@ -207,14 +208,19 @@ def _process_frame_sync(model, inputs, use_rnn, num_sampling_steps):
207
  if num_sampling_steps >= 1000:
208
  sample_latent = model.p_sample_loop(cond={'c_concat': output_from_rnn}, shape=[1, *LATENT_DIMS], return_intermediates=False, verbose=True)
209
  else:
210
- sampler = DDIMSampler(model)
211
- sample_latent, _ = sampler.sample(
212
- S=num_sampling_steps,
213
- conditioning={'c_concat': output_from_rnn},
214
- batch_size=1,
215
- shape=LATENT_DIMS,
216
- verbose=False
217
- )
 
 
 
 
 
218
  timing['unet'] = time.perf_counter() - start
219
 
220
  # Decoding
 
48
  MODEL_NAME = "yuntian-deng/computer-model-ss005-cont-lr2e5-computecanada-newnewd-unfreezernn-198k"
49
  MODEL_NAME = "yuntian-deng/computer-model-ss005-cont-lr2e5-computecanada-newnewd-freezernn-origunet-nospatial-674k"
50
  MODEL_NAME = "yuntian-deng/computer-model-ss005-cont-lr2e5-computecanada-newnewd-freezernn-origunet-nospatial-online-74k"
51
+ MODEL_NAME = "yuntian-deng/computer-model-ss005-cont-lr2e5-computecanada-newnewd-freezernn-origunet-nospatial-online-x0-22k"
52
+ MODEL_NAME = "yuntian-deng/computer-model-s-newnewd-freezernn-origunet-nospatial-online-online-70k"
53
 
54
 
55
  print (f'setting: DEBUG_MODE: {DEBUG_MODE}, DEBUG_MODE_2: {DEBUG_MODE_2}, NUM_MAX_FRAMES: {NUM_MAX_FRAMES}, NUM_SAMPLING_STEPS: {NUM_SAMPLING_STEPS}, MODEL_NAME: {MODEL_NAME}')
 
208
  if num_sampling_steps >= 1000:
209
  sample_latent = model.p_sample_loop(cond={'c_concat': output_from_rnn}, shape=[1, *LATENT_DIMS], return_intermediates=False, verbose=True)
210
  else:
211
+ if num_sampling_steps == 1:
212
+ x = torch.randn([1, *LATENT_DIMS], device=device)
213
+ t = torch.full((1,), 999, device=device, dtype=torch.long)
214
+ sample_latent = model.apply_model(x, t, {'c_concat': output_from_rnn})
215
+ else:
216
+ sampler = DDIMSampler(model)
217
+ sample_latent, _ = sampler.sample(
218
+ S=num_sampling_steps,
219
+ conditioning={'c_concat': output_from_rnn},
220
+ batch_size=1,
221
+ shape=LATENT_DIMS,
222
+ verbose=False
223
+ )
224
  timing['unet'] = time.perf_counter() - start
225
 
226
  # Decoding