File size: 2,700 Bytes
4999c45
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import gc
from pathlib import Path

import numpy as np
import pandas as pd
from datasets import load_dataset
from tqdm import tqdm
from solution import predict_wireframe


def empty_solution():
    """Build the fallback prediction used when the pipeline cannot produce one.

    Returns:
        A ``(vertices, edges)`` pair: a 2x3 float array filled with zeros and
        a freshly created, empty edge list.
    """
    vertices = np.zeros(shape=(2, 3))
    edges = []
    return vertices, edges


def _format_prediction(pred_vertices, pred_edges):
    """Convert one raw prediction into plain nested Python lists.

    The pipeline may hand back NumPy arrays (possibly multi-dimensional) or
    Python sequences. Normalising both outputs through
    ``np.asarray(...).tolist()`` gives the submission DataFrame columns of
    native lists, which ``DataFrame.to_parquet`` can serialise. ``main`` calls
    this inside its per-sample ``try`` so that a malformed prediction falls
    back to the empty solution instead of aborting the whole run.

    Args:
        pred_vertices: Array-like of vertex coordinates.
        pred_edges: Array-like / sequence of edges. NOTE(review): assumed to be
            rectangular (e.g. index pairs); a ragged sequence may make
            ``np.asarray`` raise, which ``main`` treats as a failed sample.

    Returns:
        ``(vertices, edges)`` as nested lists of native Python scalars.
    """
    vertices = np.asarray(pred_vertices).tolist()
    edges = np.asarray(pred_edges).tolist()
    return vertices, edges


def main():
    """
    Main script for the S23DR 2025 Challenge.
    This script loads the test dataset using the competition's specific
    method, runs the prediction pipeline, and saves the results.

    Every sample yields exactly one row in ``submission.parquet``. If
    ``predict_wireframe`` raises, or returns something that cannot be
    converted to lists, that row is filled from ``empty_solution()`` instead.
    """
    print("------------ Setting up data paths ------------")
    # This is the essential path where data is stored in the submission environment.
    data_path = Path('/tmp/data')

    print("------------ Loading dataset ------------")
    # This data loading logic is preserved from the original script to ensure
    # compatibility with the submission environment.
    data_files = {
        "validation": [str(p) for p in data_path.rglob('*public*/**/*.tar')],
        "test": [str(p) for p in data_path.rglob('*private*/**/*.tar')],
    }
    print(f"Found data files: {data_files}")

    dataset = load_dataset(
        str(data_path / 'hoho25k_test_x.py'),
        data_files=data_files,
        trust_remote_code=True,
        writer_batch_size=100,
    )
    print(f"Dataset loaded successfully: {dataset}")

    print('------------ Starting prediction loop ---------------')
    solution = []
    for subset_name, subset in dataset.items():
        print(f"Predicting for subset: {subset_name}")
        for i, entry in enumerate(tqdm(subset, desc=f"Processing {subset_name}")):
            try:
                # Run your prediction pipeline. The output conversion lives in
                # the same guarded region so that a malformed return value
                # (wrong arity, list instead of array, 2-D edge array, ...) is
                # handled like any other pipeline failure rather than crashing
                # the loop or the final parquet write.
                pred_vertices, pred_edges = predict_wireframe(entry)
                wf_vertices, wf_edges = _format_prediction(pred_vertices, pred_edges)
            except Exception as e:
                # If your pipeline fails, provide an empty solution and log the error.
                print(f"Error processing sample {entry.get('order_id', 'UNKNOWN')}: {e}")
                wf_vertices, wf_edges = _format_prediction(*empty_solution())

            # Append the result in the required format.
            solution.append(
                {
                    'order_id': entry['order_id'],
                    'wf_vertices': wf_vertices,
                    'wf_edges': wf_edges,
                }
            )

            # Periodically run garbage collection to manage memory.
            if (i + 1) % 50 == 0:
                gc.collect()

    print('------------ Saving results ---------------')
    sub = pd.DataFrame(solution, columns=["order_id", "wf_vertices", "wf_edges"])
    sub.to_parquet("submission.parquet", index=False)
    print("------------ Done ------------")


# Run the full pipeline only when executed as a script, not when imported.
if __name__ == "__main__":
    main()