da03 commited on
Commit
e0f0dd9
·
1 Parent(s): 832f927
Files changed (1) hide show
  1. main.py +13 -8
main.py CHANGED
@@ -139,14 +139,19 @@ def _process_frame_sync(model, inputs):
139
 
140
  # UNet sampling
141
  start = time.perf_counter()
142
- sampler = DDIMSampler(model)
143
- sample_latent, _ = sampler.sample(
144
- S=NUM_SAMPLING_STEPS,
145
- conditioning={'c_concat': output_from_rnn},
146
- batch_size=1,
147
- shape=LATENT_DIMS,
148
- verbose=False
149
- )
 
 
 
 
 
150
  timing['unet'] = time.perf_counter() - start
151
 
152
  # Decoding
 
139
 
140
  # UNet sampling
141
  start = time.perf_counter()
142
+ use_rnn = True
143
+ print (f"use_rnn: {use_rnn}")
144
+ if use_rnn:
145
+ sample_latent = output_from_rnn[:, :16]
146
+ else:
147
+ sampler = DDIMSampler(model)
148
+ sample_latent, _ = sampler.sample(
149
+ S=NUM_SAMPLING_STEPS,
150
+ conditioning={'c_concat': output_from_rnn},
151
+ batch_size=1,
152
+ shape=LATENT_DIMS,
153
+ verbose=False
154
+ )
155
  timing['unet'] = time.perf_counter() - start
156
 
157
  # Decoding