|
import torch |
|
|
|
|
|
class Fleet:
    """Batched state of a fleet of cars routed over a graph of ``num_nodes`` nodes.

    Each car begins at its own start node (referred to here as its depot) at
    its own start time. The per-batch, per-car tensors below are initialised
    from those values.

    Args:
        fleet_data: Mapping with tensors ``'start_time'`` and
            ``'car_start_node'``.
            - ``start_time.shape[0]`` is taken as the batch size.
            - ``start_time.shape[1]`` is taken as the number of cars.
            - ``car_start_node`` must be reshapeable to
              ``(batch_size, num_cars)``.
        num_nodes: Number of nodes in the graph. Start-node indices are
            assumed to lie in ``[0, num_nodes)``. Out-of-range indices are not
            detected; they simply never match any node.
        device: Device that every state tensor is moved to or allocated on.
    """

    def __init__(self, fleet_data, num_nodes, device='cpu'):
        self.device = device
        self.num_nodes = num_nodes

        self.start_time = fleet_data['start_time'].to(device)
        self.car_start_node = fleet_data['car_start_node'].to(device)

        self.batch_size = self.start_time.shape[0]
        self.num_cars = self.start_time.shape[1]

        # (batch, num_cars) long tensor holding each car's start node.
        self.depot = self.car_start_node.view(self.batch_size, self.num_cars).long()

        # (batch,) long tensor: number of *distinct* nodes used as a start
        # node. Cars that share a start node are counted once.
        self.num_depots = self._depot_node_mask().sum(dim=1)

        # Per-car running state, initialised from the start values.
        # time is cloned so that updating it never touches start_time.
        self.time = self.start_time.clone()
        # (batch, num_cars, 1) zeros; presumably the accumulated distance
        # travelled -- confirm against the code that updates it.
        self.distance = torch.zeros(self.batch_size, self.num_cars, 1, device=self.device)
        # (batch, num_cars, 1) zeros; presumably accumulated lateness -- confirm.
        self.late_time = torch.zeros(self.batch_size, self.num_cars, 1, device=self.device)

        # (batch, num_cars, 1) route so far, starting at the depot.
        # NOTE: unsqueeze returns a view, so this shares storage with self.depot.
        self.path = self.depot.unsqueeze(2)
        self.arrival_times = self.time.clone()

        # Current node of each car, initially its start node (independent copy).
        self.node = self.depot.clone()

        # (batch, num_nodes) bool mask, shared by all cars of a batch element.
        self.traversed_nodes = self.initialize_traversed_nodes()

        # (batch, num_cars) float zeros; presumably a per-car "done" flag -- confirm.
        self.finished = torch.zeros(self.batch_size, self.num_cars, device=self.device)

    def _depot_node_mask(self):
        """Return a ``(batch_size, num_nodes)`` bool mask, True where any car starts."""
        node_ids = torch.arange(self.num_nodes, device=self.device)
        # Broadcast (batch, num_cars, 1) against (num_nodes,) to get
        # (batch, num_cars, num_nodes). This avoids materialising repeated
        # copies of either operand. Both sides are integer tensors, so the
        # comparison is exact.
        return (self.depot.unsqueeze(2) == node_ids).any(dim=1)

    def initialize_traversed_nodes(self):
        """Return the initial visited-node mask.

        The mask has shape ``(batch_size, num_nodes)`` and dtype ``bool``. It
        is aggregated over cars (one row per batch element, not one per car).
        A node starts out marked as visited iff at least one car starts there.
        """
        return self._depot_node_mask()

    def construct_vector(self):
        """Return the decoder input vector: each car's current time.

        The result is ``self.time`` viewed as ``(batch_size, num_cars, 1)``.
        """
        return self.time.view(self.batch_size, self.num_cars, 1)
|
|