Update dataloader.py
Browse files- dataloader.py +14 -21
dataloader.py
CHANGED
@@ -1,46 +1,39 @@
|
|
1 |
import torch
|
2 |
import torch.nn.functional as F
|
3 |
from torch.utils.data import Dataset
|
4 |
-
|
5 |
import numpy as np
|
6 |
-
from random import shuffle
|
7 |
import os
|
8 |
-
import pandas as pd
|
9 |
|
10 |
-
class VRP_Dataset(Dataset):
|
11 |
|
|
|
12 |
def __init__(self, dataset_size, num_nodes, num_depots, dataset_path, device='cpu', *args, **kwargs):
|
13 |
super().__init__()
|
14 |
-
|
15 |
self.device = device
|
16 |
self.dataset_size = dataset_size
|
17 |
self.num_nodes = num_nodes
|
18 |
self.num_depots = num_depots
|
19 |
|
20 |
-
# Load
|
21 |
raw_data = pd.read_csv(dataset_path)
|
22 |
-
if len(raw_data) < dataset_size:
|
23 |
-
raise ValueError("
|
24 |
-
|
25 |
-
sampled_data = raw_data.sample(n=dataset_size, random_state=42).reset_index(drop=True)
|
26 |
-
|
27 |
-
# Extract coordinates (assuming columns named 'longitude', 'latitude')
|
28 |
-
coords = torch.tensor(sampled_data[['longitude', 'latitude']].values, dtype=torch.float32)
|
29 |
|
30 |
-
#
|
|
|
31 |
node_positions = coords.view(dataset_size, num_nodes, 2)
|
32 |
self.node_positions = node_positions
|
33 |
|
34 |
-
#
|
35 |
num_cars = num_nodes
|
36 |
launch_time = torch.zeros(dataset_size, num_cars, 1)
|
37 |
car_start_node = torch.randint(low=0, high=num_depots, size=(dataset_size, num_cars, 1))
|
38 |
self.fleet_data = {
|
39 |
'start_time': launch_time,
|
40 |
-
'car_start_node': car_start_node
|
41 |
}
|
42 |
|
43 |
-
#
|
44 |
a = torch.arange(num_nodes).reshape(1, 1, -1).repeat(dataset_size, num_cars, 1)
|
45 |
b = car_start_node.repeat(1, 1, num_nodes)
|
46 |
depot = ((a == b).sum(dim=1) > 0).float().unsqueeze(2)
|
@@ -63,13 +56,13 @@ class VRP_Dataset(Dataset):
|
|
63 |
def compute_distance_matrix(self, node_positions):
|
64 |
x = node_positions.unsqueeze(1).repeat(1, self.num_nodes, 1, 1)
|
65 |
y = node_positions.unsqueeze(2).repeat(1, 1, self.num_nodes, 1)
|
66 |
-
distance = (((x - y) ** 2).sum(dim=3))
|
67 |
return distance
|
68 |
|
69 |
def __getitem__(self, idx):
|
70 |
-
|
71 |
-
|
72 |
-
return
|
73 |
|
74 |
def __len__(self):
|
75 |
return self.dataset_size
|
|
|
1 |
import torch
|
2 |
import torch.nn.functional as F
|
3 |
from torch.utils.data import Dataset
|
4 |
+
import pandas as pd
|
5 |
import numpy as np
|
|
|
6 |
import os
|
|
|
7 |
|
|
|
8 |
|
9 |
+
class VRP_Dataset(Dataset):
|
10 |
def __init__(self, dataset_size, num_nodes, num_depots, dataset_path, device='cpu', *args, **kwargs):
|
11 |
super().__init__()
|
|
|
12 |
self.device = device
|
13 |
self.dataset_size = dataset_size
|
14 |
self.num_nodes = num_nodes
|
15 |
self.num_depots = num_depots
|
16 |
|
17 |
+
# Load CSV data
|
18 |
raw_data = pd.read_csv(dataset_path)
|
19 |
+
if len(raw_data) < dataset_size * num_nodes:
|
20 |
+
raise ValueError("Not enough rows in CSV to build required dataset")
|
|
|
|
|
|
|
|
|
|
|
21 |
|
22 |
+
# Randomly sample and reshape
|
23 |
+
coords = torch.tensor(raw_data[['longitude', 'latitude']].values[:dataset_size * num_nodes], dtype=torch.float32)
|
24 |
node_positions = coords.view(dataset_size, num_nodes, 2)
|
25 |
self.node_positions = node_positions
|
26 |
|
27 |
+
# Fleet data
|
28 |
num_cars = num_nodes
|
29 |
launch_time = torch.zeros(dataset_size, num_cars, 1)
|
30 |
car_start_node = torch.randint(low=0, high=num_depots, size=(dataset_size, num_cars, 1))
|
31 |
self.fleet_data = {
|
32 |
'start_time': launch_time,
|
33 |
+
'car_start_node': car_start_node,
|
34 |
}
|
35 |
|
36 |
+
# Graph data
|
37 |
a = torch.arange(num_nodes).reshape(1, 1, -1).repeat(dataset_size, num_cars, 1)
|
38 |
b = car_start_node.repeat(1, 1, num_nodes)
|
39 |
depot = ((a == b).sum(dim=1) > 0).float().unsqueeze(2)
|
|
|
56 |
def compute_distance_matrix(self, node_positions):
|
57 |
x = node_positions.unsqueeze(1).repeat(1, self.num_nodes, 1, 1)
|
58 |
y = node_positions.unsqueeze(2).repeat(1, 1, self.num_nodes, 1)
|
59 |
+
distance = torch.sqrt(((x - y) ** 2).sum(dim=3))
|
60 |
return distance
|
61 |
|
62 |
def __getitem__(self, idx):
|
63 |
+
graph = {key: self.graph_data[key][idx].unsqueeze(0).to(self.device) for key in self.graph_data}
|
64 |
+
fleet = {key: self.fleet_data[key][idx].unsqueeze(0).to(self.device) for key in self.fleet_data}
|
65 |
+
return graph, fleet
|
66 |
|
67 |
def __len__(self):
|
68 |
return self.dataset_size
|