da03 commited on
Commit
4896769
·
1 Parent(s): 8479652
Files changed (1) hide show
  1. main.py +12 -9
main.py CHANGED
@@ -148,15 +148,18 @@ def _process_frame_sync(model, inputs):
148
  if use_rnn:
149
  sample_latent = output_from_rnn[:, :16]
150
  else:
151
- NUM_SAMPLING_STEPS = 32
152
- sampler = DDIMSampler(model)
153
- sample_latent, _ = sampler.sample(
154
- S=NUM_SAMPLING_STEPS,
155
- conditioning={'c_concat': output_from_rnn},
156
- batch_size=1,
157
- shape=LATENT_DIMS,
158
- verbose=False
159
- )
 
 
 
160
  timing['unet'] = time.perf_counter() - start
161
 
162
  # Decoding
 
148
  if use_rnn:
149
  sample_latent = output_from_rnn[:, :16]
150
  else:
151
+ NUM_SAMPLING_STEPS = 1000
152
+ if NUM_SAMPLING_STEPS >= 1000:
153
+ sample_latent = model.p_sample_loop(cond={'c_concat': output_from_rnn}, shape=[1, *LATENT_DIMS], return_intermediates=False, verbose=True)
154
+ else:
155
+ sampler = DDIMSampler(model)
156
+ sample_latent, _ = sampler.sample(
157
+ S=NUM_SAMPLING_STEPS,
158
+ conditioning={'c_concat': output_from_rnn},
159
+ batch_size=1,
160
+ shape=LATENT_DIMS,
161
+ verbose=False
162
+ )
163
  timing['unet'] = time.perf_counter() - start
164
 
165
  # Decoding