File size: 4,504 Bytes
286963d
f8ecbbc
286963d
 
 
 
f8ecbbc
 
 
 
286963d
 
f8ecbbc
286963d
 
 
 
 
 
 
 
 
f8ecbbc
286963d
f8ecbbc
 
 
286963d
f8ecbbc
 
 
 
286963d
f8ecbbc
 
 
286963d
 
 
 
 
 
 
 
 
f8ecbbc
286963d
 
f8ecbbc
 
 
 
 
286963d
f8ecbbc
 
 
286963d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f8ecbbc
286963d
 
f8ecbbc
 
 
 
 
 
 
 
 
 
 
286963d
 
f8ecbbc
286963d
f8ecbbc
 
286963d
 
f8ecbbc
286963d
f8ecbbc
286963d
f8ecbbc
 
286963d
f8ecbbc
 
286963d
f8ecbbc
286963d
 
 
f8ecbbc
 
 
286963d
f8ecbbc
 
286963d
f8ecbbc
 
 
286963d
 
 
f8ecbbc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
from __future__ import print_function
from ortools.constraint_solver import routing_enums_pb2, pywrapcp
import torch
from google_solver.convert_data import convert_data


class GoogleActor:
    """
    Wrapper class to evaluate VRP solutions using Google's OR-Tools solver.

    Calling an instance converts a batch with ``convert_data`` and solves every
    resulting instance independently. It returns one total-time value per
    instance as a float tensor.
    """

    def __init__(self, scale_factor=100, max_slack=10000, max_route_time=10000):
        """
        Args:
            scale_factor (int or None): Passed to ``convert_data``. Solver times
                are divided by it to bring them back to the original units.
                ``None`` is treated as 1 (no scaling).
            max_slack (int): Maximum waiting time allowed at a location, in
                solver (scaled) time units. This is the ``slack_max`` of the
                'Time' dimension.
            max_route_time (int): Upper bound on the cumulative time of any
                vehicle, in solver (scaled) time units. This is the ``capacity``
                of the 'Time' dimension.
        """
        self.scale_factor = scale_factor if scale_factor is not None else 1
        self.max_slack = max_slack
        self.max_route_time = max_route_time

    def __call__(self, input):
        """
        Solve every instance in a batch and return their total times.

        Args:
            input: Batch of problem instances, in whatever format
                ``convert_data`` accepts (defined elsewhere).

        Returns:
            torch.Tensor: 1-D float tensor with one total time per instance.

        Raises:
            RuntimeError: If OR-Tools finds no solution for an instance.
        """
        data = convert_data(input, self.scale_factor)
        drive_times = [
            self.compute_total_time(datum, *self.compute_route(datum))
            for datum in data
        ]
        return torch.tensor(drive_times).float()

    def compute_total_time(self, data, routing, assignment):
        """
        Computes the total time spent across all routes.

        For every vehicle this takes the minimum value, in the solution, of the
        'Time' cumul at the vehicle's end node, and sums those values. The start
        cumul is not forced to zero (see ``compute_route``), so each term is the
        route's finishing time, not just its duration.

        Args:
            data (dict): Problem data; only ``data['num_vehicles']`` is read here.
            routing (RoutingModel): OR-Tools routing model.
            assignment (Assignment or None): OR-Tools solution, as returned by
                ``RoutingModel.SolveWithParameters``.

        Returns:
            float: Total time, divided by ``self.scale_factor``.

        Raises:
            RuntimeError: If ``assignment`` is None (no solution was found).
        """
        if assignment is None:
            # SolveWithParameters returns None when it finds no solution. Without
            # this check the failure surfaces as an opaque AttributeError.
            raise RuntimeError('OR-Tools found no feasible solution for this instance')

        time_dimension = routing.GetDimensionOrDie('Time')
        # Each vehicle's route always terminates at routing.End(vehicle_id), so the
        # end cumul can be read directly instead of walking the route.
        total_time = sum(
            assignment.Min(time_dimension.CumulVar(routing.End(vehicle_id)))
            for vehicle_id in range(data['num_vehicles'])
        )
        return total_time / self.scale_factor

    def compute_route(self, input):
        """
        Solves the routing problem using OR-Tools.

        Args:
            input (dict): Data with keys ``'time_matrix'``, ``'time_windows'``
                (one ``(start, end)`` pair per node), ``'num_vehicles'`` and
                ``'depot'``. Any ``'distance_matrix'`` entry is ignored; arc
                costs come from the time matrix.

        Returns:
            RoutingModel, Assignment: OR-Tools routing model and solution. The
            solution is None if no solution was found.
        """
        time_matrix = input['time_matrix']
        time_windows = input['time_windows']
        num_vehicles = input['num_vehicles']
        depot = input['depot']

        manager = pywrapcp.RoutingIndexManager(len(time_matrix), num_vehicles, depot)
        routing = pywrapcp.RoutingModel(manager)

        def time_callback(from_index, to_index):
            # Translate solver indices back to node ids. The transit value is cast
            # to int because OR-Tools callbacks must return integers; this matches
            # the int() casts applied to the time windows below.
            from_node = manager.IndexToNode(from_index)
            to_node = manager.IndexToNode(to_index)
            return int(time_matrix[from_node][to_node])

        transit_callback_index = routing.RegisterTransitCallback(time_callback)
        # Travel time is also the arc cost being minimized.
        routing.SetArcCostEvaluatorOfAllVehicles(transit_callback_index)

        routing.AddDimension(
            transit_callback_index,
            self.max_slack,       # Allowed waiting time at each location
            self.max_route_time,  # Max cumulative time per vehicle
            False,                # Don't force start cumul to zero
            'Time',
        )

        time_dimension = routing.GetDimensionOrDie('Time')

        # Time windows for all locations except the depot.
        for location_idx, (start, end) in enumerate(time_windows):
            if location_idx == depot:
                continue
            index = manager.NodeToIndex(location_idx)
            time_dimension.CumulVar(index).SetRange(int(start), int(end))

        # The depot window constrains every vehicle's start time.
        depot_start, depot_end = time_windows[depot]
        for vehicle_id in range(num_vehicles):
            index = routing.Start(vehicle_id)
            time_dimension.CumulVar(index).SetRange(int(depot_start), int(depot_end))

        # Ask the finalizer to pick the earliest feasible start and end times.
        for vehicle_id in range(num_vehicles):
            routing.AddVariableMinimizedByFinalizer(
                time_dimension.CumulVar(routing.Start(vehicle_id)))
            routing.AddVariableMinimizedByFinalizer(
                time_dimension.CumulVar(routing.End(vehicle_id)))

        search_params = pywrapcp.DefaultRoutingSearchParameters()
        search_params.first_solution_strategy = routing_enums_pb2.FirstSolutionStrategy.AUTOMATIC

        assignment = routing.SolveWithParameters(search_params)
        return routing, assignment


def evaluate_google_model(validation_dataset, scale_factor=100):
    """
    Evaluate the validation dataset using Google OR-Tools model.

    The dataset's ``device`` attribute is set to ``'cpu'`` while its data is
    fetched and solved. Afterwards it is restored to its previous value, or
    removed if it did not exist, so the caller's dataset is not permanently
    changed.

    Args:
        validation_dataset (Dataset): A dataset with a ``get_data`` method and a
            writable ``device`` attribute.
        scale_factor (int or None): Forwarded to ``GoogleActor``. The default of
            100 matches the previous hard-coded value.

    Returns:
        torch.Tensor: Scores for each batch item.
    """
    missing = object()
    previous_device = getattr(validation_dataset, 'device', missing)
    validation_dataset.device = 'cpu'
    try:
        data = validation_dataset.get_data()
        model = GoogleActor(scale_factor=scale_factor)
        return model(data)
    finally:
        # Undo the temporary device switch even if fetching or solving fails.
        if previous_device is missing:
            del validation_dataset.device
        else:
            validation_dataset.device = previous_device