Spaces:
Runtime error
Runtime error
da03
commited on
Commit
·
e4cd0fb
1
Parent(s):
dffe378
- online_data_generation.py +10 -8
online_data_generation.py
CHANGED
@@ -533,14 +533,16 @@ def main():
|
|
533 |
"""Main function to run the data processing pipeline."""
|
534 |
|
535 |
# create a padding image first
|
536 |
-
|
537 |
-
|
538 |
-
|
539 |
-
|
540 |
-
|
541 |
-
|
542 |
-
|
543 |
-
|
|
|
|
|
544 |
# Initialize database
|
545 |
initialize_database()
|
546 |
|
|
|
533 |
"""Main function to run the data processing pipeline."""
|
534 |
|
535 |
# create a padding image first
|
536 |
+
if not os.path.exists(os.path.join(OUTPUT_DIR, 'padding.npy')):
|
537 |
+
logger.info("Creating padding image...")
|
538 |
+
padding_data = np.zeros((SCREEN_HEIGHT, SCREEN_WIDTH, 3), dtype=np.uint8)
|
539 |
+
padding_tensor = torch.tensor(padding_data).unsqueeze(0)
|
540 |
+
padding_tensor = rearrange(padding_tensor, 'b h w c -> b c h w').to(device)
|
541 |
+
posterior = autoencoder.encode(padding_tensor)
|
542 |
+
latent = posterior.sample()
|
543 |
+
latent = torch.zeros_like(latent).squeeze(0)
|
544 |
+
np.save(os.path.join(OUTPUT_DIR, 'padding.npy.tmp'), latent.cpu().numpy())
|
545 |
+
os.rename(os.path.join(OUTPUT_DIR, 'padding.npy.tmp'), os.path.join(OUTPUT_DIR, 'padding.npy'))
|
546 |
# Initialize database
|
547 |
initialize_database()
|
548 |
|