a-ragab-h-m committed on
Commit
78f21d4
·
verified ·
1 Parent(s): 92e0df2

Upload 3 files

Browse files
utils/beam_search_utils.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+
4
def widen_tensor(datum, factor):
    """Repeat each dim-0 slice of ``datum`` ``factor`` times.

    A tensor of shape ``(B, ...)`` becomes ``(B * factor, ...)`` with the
    copies of each original row stored contiguously, i.e.
    ``[r0, r0, ..., r1, r1, ...]`` (interleaved) rather than the whole
    batch tiled end to end.  Scalar (0-d) tensors are returned unchanged.

    Args:
        datum: tensor to widen (any rank, including 0-d).
        factor: number of copies of each dim-0 slice.

    Returns:
        The widened tensor, or ``datum`` itself if it is 0-d.
    """
    if datum.dim() == 0:
        # Nothing to widen for a scalar.
        return datum

    shape = list(datum.shape)

    # Insert a fresh axis after dim 0 and repeat along it...
    repeats = [1, factor] + [1] * (len(shape) - 1)
    widened = datum.unsqueeze(1).repeat(*repeats)

    # ...then fold that axis back into dim 0.  ``shape[1:]`` is empty for
    # 1-D input, so no special case is needed (the original code branched
    # on ``len(L) > 1`` for no effect).
    return widened.reshape(shape[0] * factor, *shape[1:])
22
+
23
+
24
def widen_data(actor, include_embeddings=True, include_projections=True):
    """Widen every batch-indexed tensor hanging off ``actor`` by ``actor.sample_size``.

    All non-scalar tensor attributes of ``actor.fleet`` and ``actor.graph``,
    plus ``actor.log_probs`` (and optionally ``actor.node_embeddings`` and
    ``actor.node_projections``), are replaced in place by their
    ``widen_tensor`` expansion.  Presumably this fans the batch out into
    ``sample_size`` beam candidates per instance — confirm against callers.

    Args:
        actor: object carrying ``fleet``, ``graph``, ``log_probs``,
            ``sample_size`` and optionally ``node_embeddings`` /
            ``node_projections`` (a dict of tensors).
        include_embeddings: also widen ``actor.node_embeddings``.
        include_projections: also widen every tensor in
            ``actor.node_projections``.
    """
    size = actor.sample_size

    def _widen_attrs(holder):
        # Widen every non-scalar tensor attribute found on ``holder``.
        for name in dir(holder):
            value = getattr(holder, name)
            if isinstance(value, torch.Tensor) and value.dim() > 0:
                setattr(holder, name, widen_tensor(value, factor=size))

    _widen_attrs(actor.fleet)
    _widen_attrs(actor.graph)

    actor.log_probs = widen_tensor(actor.log_probs, factor=size)

    if include_embeddings:
        actor.node_embeddings = widen_tensor(actor.node_embeddings, factor=size)

    if include_projections:

        def _widen_projection(proj):
            # 4-D projections repeat along dim 1 instead of dim 0;
            # anything else falls back to the generic dim-0 widening.
            if proj.dim() > 3:
                expanded = proj.unsqueeze(2).repeat(1, 1, size, 1, 1)
                return expanded.reshape(proj.shape[0], proj.shape[1] * size,
                                        proj.shape[2], proj.shape[3])
            return widen_tensor(proj, size)

        actor.node_projections = {key: _widen_projection(value)
                                  for key, value in actor.node_projections.items()}
57
+
58
+
59
+
60
def select_data(self, index, include_embeddings=True, include_projections=True):
    """Gather beam-search state along the batch dimension.

    Every batch-indexed tensor hanging off ``self`` (attributes of
    ``self.fleet`` and ``self.graph``, plus ``self.log_probs`` and
    optionally embeddings / projections) is replaced in place by its
    ``tensor[index]`` gather.

    Args:
        self: object carrying ``fleet``, ``graph``, ``log_probs`` and
            optionally ``node_embeddings`` / ``node_projections``.
        index: 1-D integer tensor of rows to keep (may repeat rows).
        include_embeddings: also gather ``self.node_embeddings``.
        include_projections: also gather every tensor in
            ``self.node_projections``.
    """
    max_index = index.max().item()

    def _select_attrs(holder):
        # Re-index each tensor attribute that looks batch-indexed.  A
        # tensor can only be gathered with ``index`` when its first dim
        # has MORE than ``max_index`` rows; the original ``>= max_index``
        # test let through tensors of length exactly ``max_index``, which
        # then raised an out-of-range IndexError on the gather.
        for name in dir(holder):
            value = getattr(holder, name)
            if isinstance(value, torch.Tensor):
                if value.dim() > 0 and value.shape[0] > max_index:
                    setattr(holder, name, value[index])

    _select_attrs(self.fleet)
    _select_attrs(self.graph)

    self.log_probs = self.log_probs[index]

    if include_embeddings:
        self.node_embeddings = self.node_embeddings[index]

    if include_projections:

        def _select_projection(proj):
            # 4-D projections are batch-indexed along dim 1, mirroring
            # widen_data's projection handling.
            if proj.dim() > 3:
                return proj[:, index, :, :]
            return proj[index]

        self.node_projections = {key: _select_projection(self.node_projections[key])
                                 for key in self.node_projections}
91
+
utils/build_dataset.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
from fleet_beam_search_3.dataloader import VRP_Dataset


if __name__ == "__main__":
    # Guarded so importing this module no longer triggers the dataset
    # build as an import-time side effect.  The instance is built and
    # discarded; presumably VRP_Dataset persists its data as a side
    # effect of construction — confirm against the dataloader.
    VRP_Dataset(dataset_size=1000)
utils/gradient_clipping.py ADDED
File without changes