|
import torch |
|
import numpy as np |
|
|
|
|
|
def convert_tensor(x):
    """
    Convert a PyTorch tensor to a (nested) list of built-in Python integers.

    Values are cast with ``Tensor.long()``, so fractional parts are truncated
    toward zero rather than rounded.

    Args:
        x (torch.Tensor): Tensor to convert.

    Returns:
        list: Nested list mirroring the tensor's shape whose leaves are plain
            ``int`` objects (a bare ``int`` for a 0-d tensor).
    """
    # Tensor.tolist() produces built-in ints at every nesting depth.
    # Iterating a NumPy array instead would hand back numpy.int64 scalars,
    # which are not plain ints and may be rejected by strict C-extension
    # APIs or by json serialisation.
    return x.long().cpu().tolist()
|
|
|
|
|
def make_time_windows(start_time, end_time):
    """
    Concatenate start and end time tensors to form time windows.

    The two tensors are joined along their last axis, so both batched
    ``(B, N, 1)`` inputs and unbatched ``(N, 1)`` inputs are accepted.

    Args:
        start_time (torch.Tensor): Start times, shape (..., N, 1).
        end_time (torch.Tensor): End times, shape (..., N, 1).

    Returns:
        torch.Tensor: Time windows, shape (..., N, 2); index 0 of the last
            axis holds the start and index 1 holds the end.
    """
    # For 3-D input dim=-1 is the same axis as dim=2, so batched callers are
    # unaffected, while 2-D input no longer raises a dimension error.
    return torch.cat([start_time, end_time], dim=-1)
|
|
|
|
|
def convert_data(input_data, scale_factor):
    """
    Turn a batch of graph/fleet data into per-sample dictionaries.

    Each sample's distance matrix, time matrix and time windows are
    multiplied by ``scale_factor`` and then passed through
    ``convert_tensor``.

    Args:
        input_data (tuple): Pair ``(graph_data, fleet_data)``. ``graph_data``
            must provide the keys ``'start_times'``, ``'end_times'``,
            ``'distance_matrix'`` and ``'time_matrix'``. ``fleet_data`` is
            unpacked but not read.
        scale_factor (float): Multiplier applied before integer conversion.

    Returns:
        list: One dictionary per batch item with the keys
            ``'distance_matrix'``, ``'time_matrix'``, ``'time_windows'``,
            ``'depot'`` (always 0) and ``'num_vehicles'``.
    """
    graph_data, _fleet_data = input_data

    starts = graph_data['start_times']
    ends = graph_data['end_times']
    distances = graph_data['distance_matrix']
    durations = graph_data['time_matrix']
    windows = make_time_windows(starts, ends)

    def _to_scaled_ints(tensor):
        # Scale first so the fractional precision kept is set by scale_factor.
        return convert_tensor(tensor * scale_factor)

    # NOTE(review): num_vehicles is read from the distance matrix's second
    # dimension and fleet_data is ignored -- confirm this is intended.
    return [
        {
            'distance_matrix': _to_scaled_ints(distances[idx]),
            'time_matrix': _to_scaled_ints(durations[idx]),
            'time_windows': _to_scaled_ints(windows[idx]),
            'depot': 0,
            'num_vehicles': distances[idx].shape[1],
        }
        for idx in range(distances.size(0))
    ]
|
|