a-ragab-h-m commited on
Commit
b233689
·
verified ·
1 Parent(s): d5c00b7

Update google_solver/convert_data.py

Browse files
Files changed (1) hide show
  1. google_solver/convert_data.py +50 -30
google_solver/convert_data.py CHANGED
@@ -1,57 +1,77 @@
1
-
2
  import torch
3
  import numpy as np
4
 
5
 
6
  def convert_tensor(x):
7
- x = x.long().numpy().astype('int')
 
8
 
9
- if len(x.shape) == 1:
10
- return list(x)
11
- else:
12
- return [list(x[i]) for i in range(x.shape[0])]
13
 
 
 
 
 
 
 
 
14
 
15
 
16
  def make_time_windows(start_time, end_time):
 
 
 
 
 
 
 
 
 
 
17
  return torch.cat([start_time, end_time], dim=2)
18
 
19
 
 
 
 
20
 
21
- def convert_data(input, scale_factor):
 
 
22
 
23
- graph_data, fleet_data = input
 
 
 
 
 
 
 
 
24
 
25
  start_times = graph_data['start_times']
26
  end_times = graph_data['end_times']
27
-
28
  distance_matrix = graph_data['distance_matrix']
29
  time_matrix = graph_data['time_matrix']
30
-
31
  time_windows = make_time_windows(start_times, end_times)
32
 
 
 
33
 
34
- batch_size = distance_matrix.shape[0]
35
- data = []
36
  for i in range(batch_size):
 
 
 
37
 
38
- num_vehicles = distance_matrix[i].shape[1]
39
-
40
- space_mat = distance_matrix[i] * scale_factor
41
- time_mat = time_matrix[i] * scale_factor
42
- windows = time_windows[i] * scale_factor
43
-
44
- space_mat = convert_tensor(space_mat)
45
- time_mat = convert_tensor(time_mat)
46
- windows = convert_tensor(windows)
47
-
48
 
49
- D = {'distance_matrix': space_mat,
50
- 'time_matrix': time_mat,
51
- 'time_windows': windows,
52
- 'depot': 0,
53
- 'num_vehicles': num_vehicles
54
- }
55
 
56
- data.append(D)
57
- return data
 
 
1
  import torch
2
  import numpy as np
3
 
4
 
5
  def convert_tensor(x):
6
+ """
7
+ Convert a PyTorch tensor to a nested Python list of integers.
8
 
9
+ Args:
10
+ x (torch.Tensor): Tensor to convert.
 
 
11
 
12
+ Returns:
13
+ list: Converted nested list.
14
+ """
15
+ x = x.long().cpu().numpy().astype(int)
16
+ if x.ndim == 1:
17
+ return list(x)
18
+ return [list(row) for row in x]
19
 
20
 
21
  def make_time_windows(start_time, end_time):
22
+ """
23
+ Concatenate start and end time tensors to form time windows.
24
+
25
+ Args:
26
+ start_time (torch.Tensor): Start times (B, N, 1)
27
+ end_time (torch.Tensor): End times (B, N, 1)
28
+
29
+ Returns:
30
+ torch.Tensor: Time windows (B, N, 2)
31
+ """
32
  return torch.cat([start_time, end_time], dim=2)
33
 
34
 
35
+ def convert_data(input_data, scale_factor):
36
+ """
37
+ Convert batched graph and fleet data to OR-Tools compatible format.
38
 
39
+ Args:
40
+ input_data (tuple): Tuple of (graph_data, fleet_data) as dictionaries.
41
+ scale_factor (float): Scaling factor to convert float to integer.
42
 
43
+ Returns:
44
+ list: List of dictionaries, one per batch item, containing:
45
+ - distance_matrix
46
+ - time_matrix
47
+ - time_windows
48
+ - depot index (default 0)
49
+ - num_vehicles
50
+ """
51
+ graph_data, fleet_data = input_data
52
 
53
  start_times = graph_data['start_times']
54
  end_times = graph_data['end_times']
 
55
  distance_matrix = graph_data['distance_matrix']
56
  time_matrix = graph_data['time_matrix']
 
57
  time_windows = make_time_windows(start_times, end_times)
58
 
59
+ batch_size = distance_matrix.size(0)
60
+ converted_data = []
61
 
 
 
62
  for i in range(batch_size):
63
+ space_mat = (distance_matrix[i] * scale_factor)
64
+ time_mat = (time_matrix[i] * scale_factor)
65
+ windows = (time_windows[i] * scale_factor)
66
 
67
+ sample_dict = {
68
+ 'distance_matrix': convert_tensor(space_mat),
69
+ 'time_matrix': convert_tensor(time_mat),
70
+ 'time_windows': convert_tensor(windows),
71
+ 'depot': 0,
72
+ 'num_vehicles': distance_matrix[i].shape[1] # assuming square matrix
73
+ }
 
 
 
74
 
75
+ converted_data.append(sample_dict)
 
 
 
 
 
76
 
77
+ return converted_data