Spaces:
Runtime error
Runtime error
import torch | |
class Fleet(object):
    """Batched per-car state for a fleet of cars.

    Tracks each car's depot, current node, time, distance, lateness, path and
    arrival times, plus a per-batch node-visitation mask. Per-car tensors are
    indexed ``[batch, car, ...]``.
    """

    def __init__(self, fleet_data, num_nodes, device='cpu'):
        """Build the initial fleet state.

        Args:
            fleet_data: mapping with tensors under ``'start_time'`` (dim 0 is
                read as the batch size, dim 1 as the number of cars) and
                ``'car_start_node'`` (one start/depot node index per car,
                ``batch_size * num_cars`` elements).
            num_nodes: number of nodes per problem instance. Depot indices
                must lie in ``[0, num_nodes)``; ``scatter_`` rejects indices
                outside that range.
            device: torch device the state tensors are placed on.
        """
        self.device = device
        self.num_nodes = num_nodes

        # Static fields
        self.start_time = fleet_data['start_time'].to(device)
        self.car_start_node = fleet_data['car_start_node'].to(device)
        self.batch_size = self.start_time.shape[0]
        self.num_cars = self.start_time.shape[1]

        # Depot assignment per car, (B, C) long for indexing. reshape gives
        # the same result as view whenever a view is possible, but does not
        # raise for input layouts a view cannot express.
        self.depot = self.car_start_node.reshape(self.batch_size, self.num_cars).long()

        # Count distinct depots per batch, (B,) long: mark each depot node
        # once in a (B, N) table (duplicate depots write the same 1) and count
        # the marks, instead of building two (B, C, N) index grids.
        depot_marks = torch.zeros(self.batch_size, self.num_nodes, dtype=torch.long, device=self.device)
        depot_marks.scatter_(1, self.depot, 1)
        self.num_depots = depot_marks.sum(dim=1)

        # Dynamic fields
        self.time = self.start_time.clone()
        self.distance = torch.zeros(self.batch_size, self.num_cars, 1, device=self.device)
        self.late_time = torch.zeros(self.batch_size, self.num_cars, 1, device=self.device)

        # Path and arrival tracking. path is cloned so it does not share
        # storage with depot (which may itself alias the caller's
        # car_start_node tensor); an in-place write to path must not alter the
        # static depot assignment.
        self.path = self.depot.unsqueeze(2).clone()  # (B, C, 1)
        self.arrival_times = self.time.clone()

        # Current location of each car, (B, C); starts at its depot.
        self.node = self.depot.clone()

        # Node visitation mask, (B, N) bool.
        self.traversed_nodes = self.initialize_traversed_nodes()

        # Termination flag per car, (B, C) float, initialised to zeros.
        self.finished = torch.zeros(self.batch_size, self.num_cars, device=self.device)

    def initialize_traversed_nodes(self):
        """Return the initial node-visitation mask, shape (B, N), dtype bool.

        Entry ``[b, n]`` is True iff node ``n`` is the depot of at least one
        car in batch element ``b``. The mask is shared by all cars of a batch
        element; it is not tracked per car.
        """
        hits = torch.zeros(self.batch_size, self.num_nodes, dtype=torch.long, device=self.device)
        # Scatter each car's depot index into its batch row; duplicates are harmless.
        hits.scatter_(1, self.depot.reshape(self.batch_size, -1), 1)
        return hits > 0

    def construct_vector(self):
        """Return the decoder input vector: each car's current time as (B, C, 1)."""
        return self.time.view(self.batch_size, self.num_cars, 1)