import torch


class Fleet(object):
    """Batched per-car state for a fleet of cars that start at depot nodes.

    Per-car tensors are indexed ``(batch, car, ...)``. Below, ``B`` is the batch
    size and ``C`` the number of cars, both read from
    ``fleet_data['start_time'].shape[:2]``; ``N`` is ``num_nodes``.

    Args:
        fleet_data: Mapping providing at least ``'start_time'`` (first two dims
            are ``(B, C)``; ``construct_vector`` reshapes it to ``(B, C, 1)``, so
            it is expected to hold ``B * C`` elements) and ``'car_start_node'``
            (``B * C`` integer node indices, reshaped to ``(B, C)``).
        num_nodes: Number of graph nodes. Node ids are compared against
            ``range(num_nodes)``; a start node outside that range matches no node.
        device: Torch device all tensors are moved to / created on.
    """

    def __init__(self, fleet_data, num_nodes, device='cpu'):
        self.device = device
        self.num_nodes = num_nodes

        # Static fields
        self.start_time = fleet_data['start_time'].to(device)
        self.car_start_node = fleet_data['car_start_node'].to(device)
        self.batch_size = self.start_time.shape[0]
        self.num_cars = self.start_time.shape[1]

        # Depot assignment per car: (B, C) long so it can be used for indexing.
        # reshape (rather than view) also tolerates non-contiguous inputs.
        # NOTE: if car_start_node is already a long tensor on `device`, `.to`
        # and `.long()` are no-ops, so depot (and path, a view of it) share
        # storage with the caller's tensor.
        self.depot = self.car_start_node.reshape(self.batch_size, self.num_cars).long()

        # Number of distinct depot nodes per batch element: (B,) long.
        self.num_depots = self._depot_node_mask().sum(dim=1).long()

        # Dynamic fields
        self.time = self.start_time.clone()
        self.distance = torch.zeros(self.batch_size, self.num_cars, 1, device=self.device)
        self.late_time = torch.zeros(self.batch_size, self.num_cars, 1, device=self.device)

        # Path and arrival tracking; every path starts at the car's depot.
        self.path = self.depot.unsqueeze(2)  # (B, C, 1)
        self.arrival_times = self.time.clone()

        # Current location of each car: (B, C)
        self.node = self.depot.clone()

        # Node visitation mask: (B, N) bool, shared by all cars of a batch element
        self.traversed_nodes = self.initialize_traversed_nodes()

        # Termination flag per car: (B, C), float zeros (not bool)
        self.finished = torch.zeros(self.batch_size, self.num_cars, device=self.device)

    def _depot_node_mask(self):
        """Return a (B, N) bool mask: True where node n is the depot of any car in batch b."""
        node_ids = torch.arange(self.num_nodes, device=self.device)
        # Broadcast (B, C, 1) against (1, 1, N) instead of materialising repeated
        # copies; integer equality keeps the index comparison exact.
        hits = self.depot.unsqueeze(2) == node_ids.view(1, 1, -1)  # (B, C, N)
        return hits.any(dim=1)

    def initialize_traversed_nodes(self):
        """Build the initial node-visitation mask.

        Returns:
            Bool tensor of shape (B, N). Entry ``[b, n]`` is True iff node ``n``
            is the depot of at least one car in batch element ``b``. The mask
            is per batch element, not per car: cars of the same batch element
            share it. Initially only depot nodes are marked.
        """
        return self._depot_node_mask()

    def construct_vector(self):
        """Return the current time of every car, shaped (B, C, 1), as decoder input."""
        return self.time.reshape(self.batch_size, self.num_cars, 1)