import gc
import traceback
from pathlib import Path

import numpy as np
import pandas as pd
from datasets import load_dataset
from tqdm import tqdm

from solution import predict_wireframe


def empty_solution():
    """Return a minimal valid solution in case of an error.

    Returns:
        A ``(vertices, edges)`` pair: a ``(2, 3)`` array of zeros and an
        empty edge list.
    """
    return np.zeros((2, 3)), []


def _to_plain_lists(pred_vertices, pred_edges):
    """Convert a predicted wireframe into plain nested Python lists.

    ``np.asarray(...).tolist()`` accepts NumPy arrays, lists and tuples
    alike, so the caller does not depend on the exact container type
    ``predict_wireframe`` returns. Plain lists are what the parquet writer
    below can serialize (a 2-D ndarray stored in an object column cannot be).

    Args:
        pred_vertices: Array-like of predicted vertices.
        pred_edges: Array-like of predicted edges.

    Returns:
        ``(vertices, edges)`` as nested Python lists.
    """
    return np.asarray(pred_vertices).tolist(), np.asarray(pred_edges).tolist()


def _predict_entry(entry):
    """Run the prediction pipeline on one sample and build its submission row.

    Any exception raised by the pipeline *or* by converting its output is
    caught and replaced by ``empty_solution()``, so one bad sample can never
    abort the run and lose every other prediction.

    Args:
        entry: One dataset sample; must contain an ``'order_id'`` key.

    Returns:
        A dict with keys ``'order_id'``, ``'wf_vertices'`` and ``'wf_edges'``.
    """
    try:
        # Run your prediction pipeline; normalisation happens inside the
        # try so a wrongly-typed return value also falls back cleanly.
        pred_vertices, pred_edges = predict_wireframe(entry)
        vertices, edges = _to_plain_lists(pred_vertices, pred_edges)
    except Exception as e:
        # Top-level boundary: every sample must still yield a row.
        print(f"Error processing sample {entry.get('order_id', 'UNKNOWN')}: {e}")
        traceback.print_exc()
        vertices, edges = _to_plain_lists(*empty_solution())

    return {
        'order_id': entry['order_id'],
        'wf_vertices': vertices,
        'wf_edges': edges,
    }


def main():
    """
    Main script for the S23DR 2025 Challenge.

    This script loads the test dataset using the competition's specific method,
    runs the prediction pipeline on every sample of every subset, and writes
    the results to ``submission.parquet`` in the working directory.
    """
    print("------------ Setting up data paths ------------")
    # This is the essential path where data is stored in the submission environment.
    data_path = Path('/tmp/data')

    print("------------ Loading dataset ------------")
    # This data loading logic is preserved from the original script to ensure
    # compatibility with the submission environment.
    data_files = {
        "validation": [str(p) for p in data_path.rglob('*public*/**/*.tar')],
        "test": [str(p) for p in data_path.rglob('*private*/**/*.tar')],
    }
    print(f"Found data files: {data_files}")

    dataset = load_dataset(
        str(data_path / 'hoho25k_test_x.py'),
        data_files=data_files,
        trust_remote_code=True,
        writer_batch_size=100,
    )
    print(f"Dataset loaded successfully: {dataset}")

    print('------------ Starting prediction loop ---------------')
    solution = []
    for subset_name, subset in dataset.items():
        print(f"Predicting for subset: {subset_name}")
        for i, entry in enumerate(tqdm(subset, desc=f"Processing {subset_name}")):
            solution.append(_predict_entry(entry))

            # Periodically run garbage collection to manage memory.
            if (i + 1) % 50 == 0:
                gc.collect()

    print('------------ Saving results ---------------')
    sub = pd.DataFrame(solution, columns=["order_id", "wf_vertices", "wf_edges"])
    sub.to_parquet("submission.parquet", index=False)
    print("------------ Done ------------")


if __name__ == "__main__":
    main()