a-ragab-h-m's picture
Update google_solver/convert_data.py
b233689 verified
import torch
import numpy as np
def convert_tensor(x):
    """
    Convert a PyTorch tensor to a nested Python list of integers.

    Values are cast with ``.long()``, which truncates floats toward zero.
    ``Tensor.tolist()`` is used for the conversion so that every element is
    a native Python ``int`` (not a NumPy scalar) and tensors of any rank
    convert correctly:

    - a 1-D tensor gives a flat list;
    - a 2-D tensor gives a list of lists, and so on for higher ranks;
    - a 0-d tensor gives a bare ``int``.

    Args:
        x (torch.Tensor): Tensor to convert.

    Returns:
        list: Converted nested list (a plain ``int`` for a 0-d tensor).
    """
    # tolist() copies to host memory itself and yields Python ints, which
    # avoids handing numpy.int64 values to consumers that expect builtin int.
    return x.long().tolist()
def make_time_windows(start_time, end_time):
    """
    Pair start and end times into per-node [start, end] time windows.

    The two tensors are joined along dimension 2, which is their last
    dimension for the documented 3-D shape.

    Args:
        start_time (torch.Tensor): Start times, shape (B, N, 1).
        end_time (torch.Tensor): End times, shape (B, N, 1).

    Returns:
        torch.Tensor: Time windows of shape (B, N, 2), with the start time
        in channel 0 and the end time in channel 1.
    """
    pieces = (start_time, end_time)
    return torch.cat(pieces, 2)
def convert_data(input_data, scale_factor):
    """
    Convert batched graph and fleet data to OR-Tools compatible format.

    The distance matrix, time matrix and time windows are multiplied by
    ``scale_factor``. Floating-point results are then rounded to the nearest
    integer before being turned into integer lists. Without the rounding, the
    integer cast (which truncates toward zero) would turn a scaled value
    such as 28.9999 into 28 instead of 29.

    Args:
        input_data (tuple): Tuple of (graph_data, fleet_data) as dictionaries.
            ``graph_data`` must provide 'start_times', 'end_times',
            'distance_matrix' and 'time_matrix' tensors. ``fleet_data`` is
            unpacked but not read by this function.
        scale_factor (float): Scaling factor to convert float to integer.

    Returns:
        list: List of dictionaries, one per batch item, containing:
            - 'distance_matrix': nested list of ints
            - 'time_matrix': nested list of ints
            - 'time_windows': one [start, end] int pair per node
            - 'depot': depot index (always 0)
            - 'num_vehicles': width of the item's distance matrix
    """

    def _scale(tensor):
        # Scale, then round floats to nearest so the later truncating
        # integer cast does not lose a unit to floating-point error.
        scaled = tensor * scale_factor
        return scaled.round() if scaled.is_floating_point() else scaled

    graph_data, fleet_data = input_data  # fleet_data is currently unused
    distance_matrix = graph_data['distance_matrix']
    time_matrix = graph_data['time_matrix']
    time_windows = make_time_windows(graph_data['start_times'],
                                     graph_data['end_times'])

    # Scale each whole batch once instead of once per item inside the loop.
    scaled_distances = _scale(distance_matrix)
    scaled_times = _scale(time_matrix)
    scaled_windows = _scale(time_windows)

    converted_data = []
    for i in range(distance_matrix.size(0)):
        converted_data.append({
            'distance_matrix': convert_tensor(scaled_distances[i]),
            'time_matrix': convert_tensor(scaled_times[i]),
            'time_windows': convert_tensor(scaled_windows[i]),
            'depot': 0,
            # NOTE(review): this is the node count (matrix width), not a
            # value taken from fleet_data; confirm it is the intended fleet size.
            'num_vehicles': distance_matrix[i].shape[1],
        })
    return converted_data