da03 commited on
Commit
e4cd0fb
·
1 Parent(s): dffe378
Files changed (1) hide show
  1. online_data_generation.py +10 -8
online_data_generation.py CHANGED
@@ -533,14 +533,16 @@ def main():
533
  """Main function to run the data processing pipeline."""
534
 
535
  # create a padding image first
536
- padding_data = np.zeros((SCREEN_HEIGHT, SCREEN_WIDTH, 3), dtype=np.uint8)
537
- padding_tensor = torch.tensor(padding_data).unsqueeze(0)
538
- padding_tensor = rearrange(padding_tensor, 'b h w c -> b c h w').to(device)
539
- posterior = autoencoder.encode(padding_tensor)
540
- latent = posterior.sample()
541
- latent = torch.zeros_like(latent).squeeze(0)
542
- np.save(os.path.join(OUTPUT_DIR, 'padding.npy'), latent.cpu().numpy())
543
-
 
 
544
  # Initialize database
545
  initialize_database()
546
 
 
533
  """Main function to run the data processing pipeline."""
534
 
535
  # create a padding image first
536
+ if not os.path.exists(os.path.join(OUTPUT_DIR, 'padding.npy')):
537
+ logger.info("Creating padding image...")
538
+ padding_data = np.zeros((SCREEN_HEIGHT, SCREEN_WIDTH, 3), dtype=np.uint8)
539
+ padding_tensor = torch.tensor(padding_data).unsqueeze(0)
540
+ padding_tensor = rearrange(padding_tensor, 'b h w c -> b c h w').to(device)
541
+ posterior = autoencoder.encode(padding_tensor)
542
+ latent = posterior.sample()
543
+ latent = torch.zeros_like(latent).squeeze(0)
544
+ np.save(os.path.join(OUTPUT_DIR, 'padding.npy.tmp'), latent.cpu().numpy())
545
+ os.rename(os.path.join(OUTPUT_DIR, 'padding.npy.tmp'), os.path.join(OUTPUT_DIR, 'padding.npy'))
546
  # Initialize database
547
  initialize_database()
548