da03 committed
Commit f8d24b9 · 1 Parent(s): 8c27ef5
Files changed (2)
  1. main.py +2 -2
  2. utils.py +1 -1
main.py CHANGED
@@ -220,10 +220,10 @@ def predict_next_frame(previous_frames, previous_actions: List[Tuple[str, List[i
     # Prepare the image sequence for the model
     #assert len(initial_images) == 32
     image_sequence = previous_frames[-32:]  # Take the last 32 frames
-    i = 1
+    #i = 1
     while len(image_sequence) < 32:
         image_sequence.insert(0, padding_image)
-        i += 1
+        #i += 1
     #image_sequence.append(initial_images[len(image_sequence)])
 
     # Convert the image sequence to a tensor and concatenate in the channel dimension
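
The hunk above only comments out an unused loop counter; the padding behavior itself is unchanged. As a minimal sketch of that logic, assuming padding_image is a blank frame of the same shape as the real frames (the function name and default below are illustrative, not from the repo):

def pad_to_context_length(previous_frames, padding_image, context_length=32):
    # Keep only the most recent `context_length` frames.
    image_sequence = previous_frames[-context_length:]
    # Left-pad with blank frames until the model's fixed context length is reached.
    while len(image_sequence) < context_length:
        image_sequence.insert(0, padding_image)
    return image_sequence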
utils.py CHANGED
@@ -83,7 +83,7 @@ def sample_frame(model: LatentDiffusion, prompt: str, image_sequence: torch.Tens
     if DDPM:
         samples_ddim = model.p_sample_loop(cond=c, shape=[1, 4, 48, 64], return_intermediates=False, verbose=True)
     else:
-        samples_ddim, _ = sampler.sample(S=8,
+        samples_ddim, _ = sampler.sample(S=16,
                                          conditioning=c,
                                          batch_size=1,
                                          shape=[4, 48, 64],
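
This bumps the DDIM step count S from 8 to 16: each step is one denoising pass, so sampling takes roughly twice as long in exchange for a higher-fidelity predicted frame. A hedged sketch of how S is consumed, assuming the CompVis latent-diffusion DDIMSampler API (model and c stand in for the trained LatentDiffusion and the prompt conditioning from the surrounding code):

from ldm.models.diffusion.ddim import DDIMSampler

def sample_latent(model, c, steps=16):
    sampler = DDIMSampler(model)
    # S sets the number of DDIM denoising iterations over the [4, 48, 64] latent.
    samples_ddim, _ = sampler.sample(S=steps,
                                     conditioning=c,
                                     batch_size=1,
                                     shape=[4, 48, 64],
                                     verbose=False)
    return samples_ddim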