Spaces:
Runtime error
Runtime error
import torch | |
import numpy as np | |
def convert_tensor(x):
    """
    Convert a PyTorch tensor to a (nested) list of built-in Python ints.

    The tensor is cast with ``Tensor.long()``, so fractional parts are
    truncated toward zero (no rounding), exactly as before. The conversion
    then goes through ``Tensor.tolist()``, which yields plain ``int``
    objects rather than NumPy scalar types and handles any number of
    dimensions.

    Args:
        x (torch.Tensor): Tensor to convert (any device, any numeric dtype).

    Returns:
        list: Nested list of ``int`` mirroring the tensor's shape. For a
        0-dimensional tensor ``Tensor.tolist()`` returns a single ``int``.
    """
    # .cpu() is required before leaving torch; .tolist() produces native
    # Python ints, which SWIG-wrapped consumers accept without extra casts.
    return x.long().cpu().tolist()
def make_time_windows(start_time, end_time):
    """
    Join start and end times into a single time-window tensor.

    The two inputs are concatenated along dimension 2, so each must have at
    least three dimensions and agree on every other dimension.

    Args:
        start_time (torch.Tensor): Window opening times, expected (B, N, 1).
        end_time (torch.Tensor): Window closing times, expected (B, N, 1).

    Returns:
        torch.Tensor: Start/end pairs, (B, N, 2) for the expected inputs,
        with the start value first.
    """
    bounds = (start_time, end_time)
    return torch.cat(bounds, dim=2)
def convert_data(input_data, scale_factor):
    """
    Build one OR-Tools style data dictionary per batch item.

    Each tensor is multiplied by ``scale_factor`` and then passed through
    ``convert_tensor`` so the solver receives integer-valued nested lists.

    Args:
        input_data (tuple): Pair ``(graph_data, fleet_data)``. Only
            ``graph_data`` is read here; it must provide the keys
            ``'start_times'``, ``'end_times'``, ``'distance_matrix'`` and
            ``'time_matrix'``. The leading dimension of each tensor is
            treated as the batch dimension.
        scale_factor (float): Multiplier applied before integer conversion.

    Returns:
        list: One dictionary per batch item with the keys
            - ``'distance_matrix'``
            - ``'time_matrix'``
            - ``'time_windows'``
            - ``'depot'`` (always 0)
            - ``'num_vehicles'`` (size of the distance matrix's second
              dimension, i.e. the node count for a square matrix)
    """
    # The fleet half of the tuple is unpacked (so a malformed input still
    # fails here) but none of its contents are consulted.
    # NOTE(review): num_vehicles comes from the matrix size, not fleet data.
    graph, _fleet = input_data

    window_open = graph['start_times']
    window_close = graph['end_times']
    dist_batch = graph['distance_matrix']
    time_batch = graph['time_matrix']
    window_batch = make_time_windows(window_open, window_close)

    def build_sample(idx):
        # Scale this item's tensors first, then convert them for the solver.
        scaled_dist = dist_batch[idx] * scale_factor
        scaled_time = time_batch[idx] * scale_factor
        scaled_windows = window_batch[idx] * scale_factor
        return {
            'distance_matrix': convert_tensor(scaled_dist),
            'time_matrix': convert_tensor(scaled_time),
            'time_windows': convert_tensor(scaled_windows),
            'depot': 0,
            'num_vehicles': dist_batch[idx].shape[1],  # assumes square matrix
        }

    return [build_sample(idx) for idx in range(dist_batch.size(0))]