# Author: a-ragab-h-m
# Commit: b71417f "Update Actor/fleet.py" (verified)
import torch
class Fleet(object):
    """Per-car state for a batch of fleets.

    ``fleet_data`` supplies the static inputs. The constructor derives the depot
    assignment and initialises the dynamic tracking tensors. Shapes below use
    B = batch size, C = cars per batch element, N = ``num_nodes``.
    """

    def __init__(self, fleet_data, num_nodes, device='cpu'):
        """Build the fleet state.

        Args:
            fleet_data: mapping with tensors ``'start_time'`` and
                ``'car_start_node'``. The first two dims of ``start_time`` are
                taken as (B, C). Both must hold B*C values, because
                ``car_start_node`` is viewed as (B, C) and ``construct_vector``
                views the time as (B, C, 1). Presumably (B, C) or (B, C, 1);
                confirm against the caller.
            num_nodes: number of graph nodes N. Depot indices outside [0, N)
                are not counted in ``num_depots`` / ``traversed_nodes``.
            device: device every tensor is moved to or allocated on.
        """
        self.device = device
        self.num_nodes = num_nodes

        # Static fields. ``.to(device)`` returns the caller's tensor unchanged
        # when it is already on ``device``, so these may alias ``fleet_data``.
        self.start_time = fleet_data['start_time'].to(device)
        self.car_start_node = fleet_data['car_start_node'].to(device)
        self.batch_size = self.start_time.shape[0]
        self.num_cars = self.start_time.shape[1]

        # Depot assignment per car, as long so it can be used for indexing. (B, C)
        self.depot = self.car_start_node.view(self.batch_size, self.num_cars).long()
        # Number of distinct depot nodes per batch element. (B,), long
        self.num_depots = self._depot_node_mask().sum(dim=1)

        # Dynamic fields.
        self.time = self.start_time.clone()
        self.distance = torch.zeros(self.batch_size, self.num_cars, 1, device=self.device)
        self.late_time = torch.zeros(self.batch_size, self.num_cars, 1, device=self.device)

        # Path and arrival tracking. The path is cloned (like time/node) so that
        # later in-place writes cannot reach ``self.depot`` or the caller's
        # ``fleet_data['car_start_node']``.
        self.path = self.depot.unsqueeze(2).clone()  # (B, C, 1)
        self.arrival_times = self.time.clone()

        # Current location of each car. (B, C)
        self.node = self.depot.clone()

        # Node visitation mask. (B, N), bool
        self.traversed_nodes = self.initialize_traversed_nodes()

        # Termination flag per car: (B, C) zeros of the default float dtype.
        self.finished = torch.zeros(self.batch_size, self.num_cars, device=self.device)

    def _depot_node_mask(self):
        """Return a (B, N) bool mask that is True where node n is some car's depot."""
        node_ids = torch.arange(self.num_nodes, device=self.device).view(1, 1, -1)  # (1, 1, N)
        # Broadcast (B, C, 1) against (1, 1, N) rather than materialising
        # repeated (B, C, N) copies, then reduce over the car dimension.
        return (self.depot.unsqueeze(2) == node_ids).any(dim=1)

    def initialize_traversed_nodes(self):
        """Return the initial node-visitation mask.

        The mask has shape (B, N) and dtype bool. Entry [b, n] is True when node
        n is the depot of any car in batch element b, so initially only depot
        nodes count as visited. The mask is shared by all cars of a batch
        element; it is not tracked per car.
        """
        return self._depot_node_mask()

    def construct_vector(self):
        """Return each car's current time, reshaped to (B, C, 1), as decoder input."""
        return self.time.view(self.batch_size, self.num_cars, 1)