File size: 2,406 Bytes
a3d45a8
 
 
 
 
 
 
 
 
b71417f
 
 
a3d45a8
 
 
 
b71417f
 
a3d45a8
b71417f
 
 
 
a3d45a8
b71417f
 
 
 
a3d45a8
b71417f
 
 
a3d45a8
b71417f
 
a3d45a8
b71417f
a3d45a8
 
b71417f
 
a3d45a8
 
b71417f
 
 
 
 
 
 
 
a3d45a8
 
b71417f
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import torch


class Fleet(object):
    """Batched state of a fleet of cars moving over a graph of ``num_nodes`` nodes.

    ``fleet_data`` must provide ``'start_time'`` and ``'car_start_node'``
    tensors. The first two dimensions of ``start_time`` are read as
    (batch_size, num_cars), and ``car_start_node`` must be reshapeable to
    (batch_size, num_cars). Each car's start node is used as its depot.
    """

    def __init__(self, fleet_data, num_nodes, device='cpu'):
        self.device = device
        self.num_nodes = num_nodes

        # Static fields
        self.start_time = fleet_data['start_time'].to(device)
        self.car_start_node = fleet_data['car_start_node'].to(device)

        self.batch_size = self.start_time.shape[0]
        self.num_cars = self.start_time.shape[1]

        # Depot assignment per car, (B, C); long so it can be used for indexing.
        self.depot = self.car_start_node.view(self.batch_size, self.num_cars).long()

        # Number of distinct depot nodes per batch element, (B,) long
        # (summing a bool tensor yields int64).
        self.num_depots = self._depot_mask().sum(dim=1)

        # Dynamic fields
        self.time = self.start_time.clone()
        self.distance = torch.zeros(self.batch_size, self.num_cars, 1, device=self.device)
        self.late_time = torch.zeros(self.batch_size, self.num_cars, 1, device=self.device)

        # Path and arrival tracking. ``unsqueeze`` returns a view, so clone it:
        # otherwise an in-place write to the path would silently modify
        # ``self.depot`` (the same reason ``self.node`` is cloned below).
        self.path = self.depot.unsqueeze(2).clone()  # (B, C, 1)
        self.arrival_times = self.time.clone()

        # Current location
        self.node = self.depot.clone()

        # Node visitation mask
        self.traversed_nodes = self.initialize_traversed_nodes()

        # Termination flag
        self.finished = torch.zeros(self.batch_size, self.num_cars, device=self.device)

    def _depot_mask(self):
        """Return a fresh (B, N) bool mask, True where at least one car of that batch element starts."""
        node_ids = torch.arange(self.num_nodes, device=self.device)
        # (B, C, 1) == (N,) broadcasts to (B, C, N) without materialising
        # repeated copies; compare integers directly (no float round-trip).
        return (self.depot.unsqueeze(-1) == node_ids).any(dim=1)

    def initialize_traversed_nodes(self):
        """
        Build the initial node-visitation mask.

        Returns a bool tensor of shape (batch_size, num_nodes). A node is marked
        visited in a batch element if any car of that batch element starts
        (has its depot) there. The mask is shared by all cars of a batch
        element; it is not tracked per car.
        """
        return self._depot_mask()

    def construct_vector(self):
        """
        Constructs the input vector used in the decoder: the current time of
        each car, reshaped to (batch_size, num_cars, 1).
        """
        return self.time.view(self.batch_size, self.num_cars, 1)