File size: 4,614 Bytes
ac59957
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
import argparse
import os
import cv2

import torch

from config.parser import parse_args

from core import datasets
from core.memfof_lit import MEMFOFLit
from tqdm import tqdm

from core.utils.flow_viz import flow_to_image
from core.utils import frame_utils


@torch.inference_mode()
def create_spring_submission(model: MEMFOFLit, device: str, output_path: str) -> None:
    """Create submission for the Spring leaderboard.

    Iterates over the Spring submission split, runs the model on every sample
    and writes two files per sample into
    ``<output_path>/<scene>/flow_<direction>_<cam>/``: the predicted flow as a
    ``.flo5`` file and a PNG visualization of that flow.

    Args:
        model: Model providing ``scale_and_forward_flow``.
        device: Torch device the input images are moved to before inference.
        output_path: Root directory under which the per-scene folders are
            created.
    """
    test_dataset = datasets.three_frame_wrapper_spring_submission(
        datasets.SpringFlowDataset, {"split": "submission"}
    )
    for test_id in tqdm(range(len(test_dataset))):
        images, extra_info = test_dataset[test_id]
        scene, frame, _, frames, _ = extra_info
        # Add a leading batch axis of size one for the model.
        images = images.unsqueeze(0).to(device)

        flow, _ = model.scale_and_forward_flow(images, scale=0)

        # Take the single batch element and move its leading axis last before
        # handing the array to the visualizer and the file writer.
        flow = flow[0].permute(1, 2, 0).cpu().numpy()
        flow_vis = flow_to_image(flow, convert_to_bgr=True)

        cam = frames[0][1]
        # NOTE(review): a negative first entry is mapped to the forward ("FW")
        # direction; the meaning of that entry is defined by the dataset
        # wrapper - confirm against it.
        direction = "FW" if frames[0][0] < 0 else "BW"

        folder_name = f"flow_{direction}_{cam}"
        output_dir = os.path.join(output_path, scene, folder_name)
        # Both output files share the same stem and differ only in extension.
        file_stem = os.path.join(output_dir, f"{folder_name}_{frame + 1:04d}")

        # exist_ok avoids the check-then-create race of a separate exists() test.
        os.makedirs(output_dir, exist_ok=True)

        cv2.imwrite(f"{file_stem}.png", flow_vis)
        frame_utils.writeFlo5File(flow, f"{file_stem}.flo5")


@torch.inference_mode()
def create_sintel_submission(model: MEMFOFLit, device: str, output_path: str) -> None:
    """Create submission for the Sintel leaderboard.

    For both the ``clean`` and ``final`` passes, runs the model on every
    sample of the submission split and writes the predicted flow as a ``.flo``
    file plus a PNG visualization into ``<output_path>/<dstype>/<scene>/``.

    Args:
        model: Model providing ``scale_and_forward_flow``.
        device: Torch device the input images are moved to before inference.
        output_path: Root directory under which the per-pass folders are
            created.
    """
    for dstype in ["clean", "final"]:
        test_dataset = datasets.three_frame_wrapper_sintel_submission(
            datasets.MpiSintel, {"split": "submission", "dstype": dstype}
        )
        for test_id in tqdm(range(len(test_dataset))):
            images, extra_info = test_dataset[test_id]
            scene, frame, _, _, _ = extra_info
            # Add a leading batch axis of size one for the model.
            images = images.unsqueeze(0).to(device)

            flow, _ = model.scale_and_forward_flow(images, scale=1)
            # Take the single batch element and move its leading axis last
            # before handing the array to the visualizer and the file writer.
            flow = flow[0].permute(1, 2, 0).cpu().numpy()
            flow_vis = flow_to_image(flow, convert_to_bgr=True)

            output_dir = os.path.join(output_path, dstype, scene)
            output_file = os.path.join(output_dir, "frame%04d.flo" % (frame + 1))

            # exist_ok avoids the check-then-create race of a separate
            # exists() test.
            os.makedirs(output_dir, exist_ok=True)

            frame_utils.writeFlow(output_file, flow)
            # NOTE(review): unlike the .flo file, the PNG name is not
            # zero-padded; kept as-is so existing outputs keep their names.
            cv2.imwrite(os.path.join(output_dir, f"frame{frame + 1}.png"), flow_vis)


@torch.inference_mode()
def create_kitti_submission(model: MEMFOFLit, device: str, output_path: str) -> None:
    """Create submission for the KITTI leaderboard.

    Runs the model on every sample of the KITTI submission split and writes,
    directly into ``output_path``, the predicted flow as ``<id>_10.png`` (via
    ``frame_utils.writeFlowKITTI``) and a visualization as
    ``frame<id>_10.png``, where ``<id>`` is the zero-padded sample index.

    Args:
        model: Model providing ``scale_and_forward_flow``.
        device: Torch device the input images are moved to before inference.
        output_path: Directory that receives all generated files.
    """
    test_dataset = datasets.three_frame_wrapper_kitti_submission(
        datasets.KITTI, {"split": "submission", "aug_params": None}
    )

    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(output_path, exist_ok=True)

    for test_id in tqdm(range(len(test_dataset))):
        images, _ = test_dataset[test_id]
        # The file name is derived from the sample index alone.
        # NOTE(review): this assumes the dataset order matches the benchmark's
        # frame numbering - confirm against the dataset wrapper.
        frame = f"{test_id:06d}_10.png"
        # Add a leading batch axis of size one for the model.
        images = images.unsqueeze(0).to(device)

        flow, _ = model.scale_and_forward_flow(images, scale=1)
        # Take the single batch element and move its leading axis last before
        # handing the array to the visualizer and the file writer.
        flow = flow[0].permute(1, 2, 0).cpu().numpy()
        flow_vis = flow_to_image(flow, convert_to_bgr=True)

        output_filename = os.path.join(output_path, frame)
        # The visualization gets a "frame" prefix so it does not overwrite the
        # flow file written to the same directory.
        cv2.imwrite(os.path.join(output_path, f"frame{frame}"), flow_vis)
        frame_utils.writeFlowKITTI(output_filename, flow)


@torch.inference_mode()
def eval(args):
    """Build the model and generate the submission for ``args.dataset``.

    Note: the function name shadows the ``eval`` builtin inside this module;
    it is kept because callers refer to it by this name.

    Args:
        args: Namespace providing ``dataset`` (one of ``"spring"``,
            ``"sintel"``, ``"kitti"``) and ``output_dir``; it is also passed
            unchanged to the ``MEMFOFLit`` constructor.

    Raises:
        ValueError: If ``args.dataset`` is not one of the supported names.
    """
    # Validate first so that an unsupported dataset fails before the model is
    # constructed and moved to the device.
    if args.dataset not in ("spring", "sintel", "kitti"):
        raise ValueError(f"Unknown dataset {args.dataset} requested for evaluation")

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = MEMFOFLit(args).to(device).eval()
    # Each dataset gets its own sub-folder below the common output directory.
    output_path = os.path.join(args.output_dir, args.dataset)

    if args.dataset == "spring":
        create_spring_submission(model, device, output_path)
    elif args.dataset == "sintel":
        create_sintel_submission(model, device, output_path)
    else:
        create_kitti_submission(model, device, output_path)


def main():
    """Parse the command line, create the output directory and run ``eval``.

    Command line:
        output_dir: Optional positional root directory for the generated
            submissions (default ``"submissions"``).
        --cfg: Required experiment config file name.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "output_dir",
        type=str,
        help="Root directory for the generated submissions",
        nargs="?",
        default="submissions",
    )
    parser.add_argument("--cfg", help="experiment config file name", required=True, type=str)
    # parse_args is handed only the parser, so it has to read the command line
    # itself; a separate parser.parse_args() whose result would be overwritten
    # right away is therefore not needed.
    # NOTE(review): confirm config.parser.parse_args parses the CLI strictly.
    args = parse_args(parser)
    os.makedirs(args.output_dir, exist_ok=True)
    eval(args)


# Run the submission generator only when executed as a script, not on import.
if __name__ == "__main__":
    main()