Spaces:
Runtime error
Runtime error
| from __future__ import print_function | |
| from ortools.constraint_solver import routing_enums_pb2, pywrapcp | |
| import torch | |
| from google_solver.convert_data import convert_data | |
class GoogleActor:
    """
    Wrapper class to evaluate VRP solutions using Google's OR-Tools solver.

    Each problem instance produced by ``convert_data`` is solved independently
    with OR-Tools, and the per-instance total route time is returned as a
    float tensor.
    """

    def __init__(self, scale_factor=100, max_slack=10000, max_route_time=10000):
        """
        Args:
            scale_factor (number | None): Factor passed to ``convert_data`` and
                used to divide the solver's integer times back to the original
                units. ``None`` is treated as 1 (no scaling).
            max_slack (int): Slack of the 'Time' dimension, i.e. the maximum
                waiting time allowed at a node, in scaled units.
            max_route_time (int): Capacity of the 'Time' dimension, i.e. the
                maximum cumulative time per vehicle, in scaled units.
        """
        self.scale_factor = scale_factor if scale_factor is not None else 1
        self.max_slack = max_slack
        self.max_route_time = max_route_time

    def __call__(self, input):
        """
        Solve every instance in ``input`` and return their total route times.

        Args:
            input: Batch accepted by ``convert_data`` (format defined there).

        Returns:
            torch.Tensor: 1-D float tensor with one total time per instance.

        Raises:
            RuntimeError: If OR-Tools finds no solution for some instance.
        """
        data = convert_data(input, self.scale_factor)
        drive_times = []
        for datum in data:
            routing, assignment = self.compute_route(datum)
            drive_times.append(self.compute_total_time(datum, routing, assignment))
        return torch.tensor(drive_times).float()

    def compute_total_time(self, data, routing, assignment):
        """
        Computes the total time spent across all routes.

        For each vehicle, the 'Time' cumul at the vehicle's end node (the time
        at which it finishes its route) is read from the solution, and the
        values are summed. The start cumul is not forced to zero (see
        ``compute_route``), so a non-zero depot start time is included.

        Args:
            data (dict): Problem data; only ``data['num_vehicles']`` is read.
            routing (RoutingModel): OR-Tools routing model.
            assignment (Assignment | None): OR-Tools solution, or ``None`` if
                the solver found none.

        Returns:
            float: Total time divided by ``self.scale_factor``.

        Raises:
            RuntimeError: If ``assignment`` is ``None``.
        """
        # SolveWithParameters returns None when no solution is found; fail
        # loudly instead of with an AttributeError on None.
        if assignment is None:
            raise RuntimeError(
                "OR-Tools found no solution for this routing instance."
            )
        time_dimension = routing.GetDimensionOrDie('Time')
        total_time = 0
        for vehicle_id in range(data['num_vehicles']):
            # Following NextVar from Start(vehicle_id) terminates at
            # End(vehicle_id), so read the end cumul directly.
            end_index = routing.End(vehicle_id)
            total_time += assignment.Min(time_dimension.CumulVar(end_index))
        return total_time / self.scale_factor

    def compute_route(self, input):
        """
        Solves the routing problem using OR-Tools.

        Args:
            input (dict): Data with keys 'time_matrix', 'time_windows'
                (one (start, end) pair per node), 'num_vehicles' and 'depot'.

        Returns:
            RoutingModel, Assignment: OR-Tools routing model and solution
            (the solution is ``None`` if none was found).
        """
        time_matrix = input['time_matrix']
        time_windows = input['time_windows']
        num_vehicles = input['num_vehicles']
        depot = input['depot']

        manager = pywrapcp.RoutingIndexManager(len(time_matrix), num_vehicles, depot)
        routing = pywrapcp.RoutingModel(manager)

        def time_callback(from_index, to_index):
            # The solver passes internal indices; convert them to node ids.
            from_node = manager.IndexToNode(from_index)
            to_node = manager.IndexToNode(to_index)
            # OR-Tools transit callbacks must return integers.
            return int(time_matrix[from_node][to_node])

        transit_callback_index = routing.RegisterTransitCallback(time_callback)
        routing.SetArcCostEvaluatorOfAllVehicles(transit_callback_index)
        routing.AddDimension(
            transit_callback_index,
            self.max_slack,       # Allow waiting time
            self.max_route_time,  # Max time per vehicle
            False,                # Don't force start cumul to zero
            'Time',
        )
        time_dimension = routing.GetDimensionOrDie('Time')

        # Time windows for all locations except depot
        for location_idx, (start, end) in enumerate(time_windows):
            if location_idx == depot:
                continue
            index = manager.NodeToIndex(location_idx)
            time_dimension.CumulVar(index).SetRange(int(start), int(end))

        # Time windows for vehicle start (depot)
        depot_start, depot_end = time_windows[depot]
        for vehicle_id in range(num_vehicles):
            index = routing.Start(vehicle_id)
            time_dimension.CumulVar(index).SetRange(int(depot_start), int(depot_end))

        # Ask the finalizer to minimize route start and end times
        for vehicle_id in range(num_vehicles):
            routing.AddVariableMinimizedByFinalizer(
                time_dimension.CumulVar(routing.Start(vehicle_id)))
            routing.AddVariableMinimizedByFinalizer(
                time_dimension.CumulVar(routing.End(vehicle_id)))

        search_params = pywrapcp.DefaultRoutingSearchParameters()
        search_params.first_solution_strategy = (
            routing_enums_pb2.FirstSolutionStrategy.AUTOMATIC)
        assignment = routing.SolveWithParameters(search_params)
        return routing, assignment
def evaluate_google_model(validation_dataset, scale_factor=100):
    """
    Evaluate the validation dataset using Google OR-Tools model.

    Args:
        validation_dataset (Dataset): A dataset with a ``get_data`` method;
            its ``device`` attribute is set to ``'cpu'`` before data is read.
        scale_factor (number): Scale factor forwarded to ``GoogleActor``;
            defaults to 100, the previously hard-coded value.

    Returns:
        torch.Tensor: Scores (total route times computed by ``GoogleActor``)
        for each batch item.
    """
    # NOTE(review): this permanently switches the caller's dataset to 'cpu'
    # and does not restore the previous device. Presumably it is done so
    # get_data() yields CPU tensors for the OR-Tools conversion; confirm
    # against the dataset implementation.
    validation_dataset.device = 'cpu'
    data = validation_dataset.get_data()
    model = GoogleActor(scale_factor=scale_factor)
    return model(data)