from __future__ import print_function from ortools.constraint_solver import routing_enums_pb2, pywrapcp import torch from google_solver.convert_data import convert_data class GoogleActor: """ Wrapper class to evaluate VRP solutions using Google's OR-Tools solver. """ def __init__(self, scale_factor=100): self.scale_factor = scale_factor if scale_factor is not None else 1 def __call__(self, input): drive_times = [] data = convert_data(input, self.scale_factor) for datum in data: routing, assignment = self.compute_route(datum) total_time = self.compute_total_time(datum, routing, assignment) drive_times.append(total_time) return torch.tensor(drive_times).float() def compute_total_time(self, data, routing, assignment): """ Computes the total time spent across all routes. Args: data (dict): Problem data with time matrix and vehicle count. routing (RoutingModel): OR-Tools routing model. assignment (Assignment): OR-Tools assignment solution. Returns: float: Total time (scaled back). """ time_dimension = routing.GetDimensionOrDie('Time') total_time = 0 for vehicle_id in range(data['num_vehicles']): index = routing.Start(vehicle_id) while not routing.IsEnd(index): index = assignment.Value(routing.NextVar(index)) time_var = time_dimension.CumulVar(index) total_time += assignment.Min(time_var) return total_time / self.scale_factor def compute_route(self, input): """ Solves the routing problem using OR-Tools. Args: input (dict): Data containing distance, time matrix, time windows, and depot index. Returns: RoutingModel, Assignment: OR-Tools routing and solution. """ distance_matrix = input['distance_matrix'] time_matrix = input['time_matrix'] time_windows = input['time_windows'] num_vehicles = input['num_vehicles'] depot = input['depot'] manager = pywrapcp.RoutingIndexManager(len(time_matrix), num_vehicles, depot) routing = pywrapcp.RoutingModel(manager) def time_callback(from_index, to_index): from_node = manager.IndexToNode(from_index) to_node = manager.IndexToNode(to_index) return time_matrix[from_node][to_node] transit_callback_index = routing.RegisterTransitCallback(time_callback) routing.SetArcCostEvaluatorOfAllVehicles(transit_callback_index) routing.AddDimension( transit_callback_index, 10000, # Allow waiting time 10000, # Max time per vehicle False, # Don't force start cumul to zero 'Time' ) time_dimension = routing.GetDimensionOrDie('Time') # Time windows for all locations except depot for location_idx, (start, end) in enumerate(time_windows): if location_idx == depot: continue index = manager.NodeToIndex(location_idx) time_dimension.CumulVar(index).SetRange(int(start), int(end)) # Time windows for vehicle start (depot) depot_start, depot_end = time_windows[depot] for vehicle_id in range(num_vehicles): index = routing.Start(vehicle_id) time_dimension.CumulVar(index).SetRange(int(depot_start), int(depot_end)) # Finalizer hints for optimization for i in range(num_vehicles): routing.AddVariableMinimizedByFinalizer(time_dimension.CumulVar(routing.Start(i))) routing.AddVariableMinimizedByFinalizer(time_dimension.CumulVar(routing.End(i))) search_params = pywrapcp.DefaultRoutingSearchParameters() search_params.first_solution_strategy = routing_enums_pb2.FirstSolutionStrategy.AUTOMATIC assignment = routing.SolveWithParameters(search_params) return routing, assignment def evaluate_google_model(validation_dataset): """ Evaluate the validation dataset using Google OR-Tools model. Args: validation_dataset (Dataset): A dataset with a get_data method. Returns: torch.Tensor: Scores for each batch item. """ validation_dataset.device = 'cpu' data = validation_dataset.get_data() model = GoogleActor(scale_factor=100) return model(data)