# Author: a-ragab-h-m
# Last commit: f8ecbbc - "Update google_solver/google_model.py"
from __future__ import print_function
from ortools.constraint_solver import routing_enums_pb2, pywrapcp
import torch
from google_solver.convert_data import convert_data
class GoogleActor:
    """
    Wrapper class to evaluate VRP solutions using Google's OR-Tools solver.

    Each problem instance produced by ``convert_data`` is solved independently
    with OR-Tools, and the total route time of the solution is reported after
    dividing by ``scale_factor``.
    """

    def __init__(self, scale_factor=100, max_waiting_time=10000, max_route_time=10000):
        """
        Args:
            scale_factor (float | None): Passed to ``convert_data`` (presumably
                used there to scale inputs to solver integers - confirm) and
                used to divide the solver's times on output. ``None`` means 1.
            max_waiting_time (int): Slack allowed at each node of the 'Time'
                dimension, in solver (scaled) units.
            max_route_time (int): Capacity of the 'Time' dimension, i.e. the
                maximum cumulative time per vehicle, in solver (scaled) units.
        """
        self.scale_factor = scale_factor if scale_factor is not None else 1
        self.max_waiting_time = max_waiting_time
        self.max_route_time = max_route_time

    def __call__(self, input):
        """
        Solve every instance in ``input`` and return their total times.

        Args:
            input: Batch in the format accepted by ``convert_data``.

        Returns:
            torch.Tensor: 1-D float32 tensor with one total time per instance.
        """
        data = convert_data(input, self.scale_factor)
        drive_times = []
        for datum in data:
            routing, assignment = self.compute_route(datum)
            drive_times.append(self.compute_total_time(datum, routing, assignment))
        return torch.tensor(drive_times, dtype=torch.float32)

    def compute_total_time(self, data, routing, assignment):
        """
        Sum the route completion times of all vehicles.

        For each vehicle, this reads the 'Time' cumul at that vehicle's end
        index, which is the time at which its route finishes. This mirrors the
        "total time of all routes" computation in the OR-Tools VRPTW example.

        Args:
            data (dict): Problem data; only ``data['num_vehicles']`` is read.
            routing (RoutingModel): OR-Tools routing model with a 'Time' dimension.
            assignment (Assignment | None): Solution returned by the solver.

        Returns:
            float: Total time divided by ``self.scale_factor``.

        Raises:
            RuntimeError: If ``assignment`` is None (the solver found no solution).
        """
        if assignment is None:
            # SolveWithParameters returns None when no solution is found;
            # fail with a clear message instead of an AttributeError.
            raise RuntimeError("OR-Tools found no feasible solution for this instance")
        time_dimension = routing.GetDimensionOrDie('Time')
        total_time = 0
        for vehicle_id in range(data['num_vehicles']):
            end_index = routing.End(vehicle_id)
            total_time += assignment.Min(time_dimension.CumulVar(end_index))
        return total_time / self.scale_factor

    def compute_route(self, input):
        """
        Solve the time-windowed routing problem using OR-Tools.

        Args:
            input (dict): Instance data. Reads 'time_matrix' (square matrix of
                travel times), 'time_windows' (one (start, end) pair per
                location), 'num_vehicles' and 'depot' (depot node id).

        Returns:
            tuple: ``(routing, assignment)``. ``routing`` is the OR-Tools
            RoutingModel; ``assignment`` is the solution, or None if the
            solver found none.
        """
        time_matrix = input['time_matrix']
        time_windows = input['time_windows']
        num_vehicles = input['num_vehicles']
        depot = input['depot']

        manager = pywrapcp.RoutingIndexManager(len(time_matrix), num_vehicles, depot)
        routing = pywrapcp.RoutingModel(manager)

        def time_callback(from_index, to_index):
            # Map solver indices back to node ids. OR-Tools transit callbacks
            # must return integers, so cast the same way the time windows are.
            from_node = manager.IndexToNode(from_index)
            to_node = manager.IndexToNode(to_index)
            return int(time_matrix[from_node][to_node])

        transit_callback_index = routing.RegisterTransitCallback(time_callback)
        routing.SetArcCostEvaluatorOfAllVehicles(transit_callback_index)
        routing.AddDimension(
            transit_callback_index,
            self.max_waiting_time,  # slack: allowed waiting time at a node
            self.max_route_time,  # capacity: max cumulative time per vehicle
            False,  # don't force start cumul to zero
            'Time',
        )
        time_dimension = routing.GetDimensionOrDie('Time')

        # Time windows for all locations except the depot. The depot is
        # represented by per-vehicle start/end indices, handled below.
        for location_idx, (start, end) in enumerate(time_windows):
            if location_idx == depot:
                continue
            index = manager.NodeToIndex(location_idx)
            time_dimension.CumulVar(index).SetRange(int(start), int(end))

        # Depot time window applied to every vehicle's start.
        depot_start, depot_end = time_windows[depot]
        for vehicle_id in range(num_vehicles):
            index = routing.Start(vehicle_id)
            time_dimension.CumulVar(index).SetRange(int(depot_start), int(depot_end))

        # Ask the finalizer to pick the smallest feasible start/end times.
        for vehicle_id in range(num_vehicles):
            routing.AddVariableMinimizedByFinalizer(
                time_dimension.CumulVar(routing.Start(vehicle_id)))
            routing.AddVariableMinimizedByFinalizer(
                time_dimension.CumulVar(routing.End(vehicle_id)))

        search_params = pywrapcp.DefaultRoutingSearchParameters()
        search_params.first_solution_strategy = (
            routing_enums_pb2.FirstSolutionStrategy.AUTOMATIC)
        assignment = routing.SolveWithParameters(search_params)
        return routing, assignment
def evaluate_google_model(validation_dataset, scale_factor=100):
    """
    Evaluate the validation dataset using the Google OR-Tools model.

    Args:
        validation_dataset (Dataset): Object with a writable ``device``
            attribute and a ``get_data()`` method.
        scale_factor (float | None): Forwarded to ``GoogleActor``; the default
            of 100 matches the previous hard-coded value.

    Returns:
        torch.Tensor: One OR-Tools total time per batch item.
    """
    # NOTE(review): this leaves the dataset's device set to 'cpu' after the
    # call (presumably so get_data() yields CPU data for OR-Tools). Callers
    # reusing the dataset afterwards should confirm that is acceptable.
    validation_dataset.device = 'cpu'
    data = validation_dataset.get_data()
    actor = GoogleActor(scale_factor=scale_factor)
    return actor(data)