sophiat44 commited on
Commit
5a87d8d
·
1 Parent(s): 6501779

model upload

Browse files
Files changed (43) hide show
  1. branchsbm/.DS_Store +0 -0
  2. branchsbm/branch_flow_net_train.py +348 -0
  3. branchsbm/branch_growth_net_train.py +514 -0
  4. branchsbm/branch_interpolant_train.py +398 -0
  5. branchsbm/branchsbm.py +109 -0
  6. branchsbm/ema.py +64 -0
  7. configs/.DS_Store +0 -0
  8. configs/experiment/cell_single_branch.yaml +12 -0
  9. configs/experiment/clonidine_100D.yaml +22 -0
  10. configs/experiment/clonidine_150D.yaml +22 -0
  11. configs/experiment/clonidine_50D.yaml +22 -0
  12. configs/experiment/clonidine_50Dsingle.yaml +22 -0
  13. configs/experiment/lidar.yaml +14 -0
  14. configs/experiment/lidar_single.yaml +14 -0
  15. configs/experiment/mouse.yaml +17 -0
  16. configs/experiment/trametinib.yaml +22 -0
  17. configs/experiment/trametinib_single.yaml +22 -0
  18. dataloaders/.DS_Store +0 -0
  19. dataloaders/clonidine_data.py +269 -0
  20. dataloaders/clonidine_single_branch.py +274 -0
  21. dataloaders/clonidine_v2_data.py +287 -0
  22. dataloaders/lidar_data.py +532 -0
  23. dataloaders/lidar_data_single.py +282 -0
  24. dataloaders/mouse_data.py +438 -0
  25. dataloaders/three_branch_data.py +310 -0
  26. dataloaders/trametinib_single.py +279 -0
  27. losses/.DS_Store +0 -0
  28. losses/energy_loss.py +73 -0
  29. networks/.DS_Store +0 -0
  30. networks/flow_mlp.py +18 -0
  31. networks/growth_mlp.py +37 -0
  32. networks/interpolant_mlp.py +35 -0
  33. networks/mlp_base.py +46 -0
  34. networks/utils.py +13 -0
  35. state_costs/.DS_Store +0 -0
  36. state_costs/land.py +26 -0
  37. state_costs/metric_factory.py +105 -0
  38. state_costs/rbf.py +156 -0
  39. train/.DS_Store +0 -0
  40. train/main_branches.py +342 -0
  41. train/parsers.py +419 -0
  42. train/train_utils.py +154 -0
  43. utils.py +198 -0
branchsbm/.DS_Store ADDED
Binary file (6.15 kB). View file
 
branchsbm/branch_flow_net_train.py ADDED
@@ -0,0 +1,348 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ sys.path.append("./BranchSBM")
4
+ import torch
5
+ import wandb
6
+ import matplotlib.pyplot as plt
7
+ import pytorch_lightning as pl
8
+ from torch.optim import AdamW
9
+ from torchmetrics.functional import mean_squared_error
10
+ from torchdyn.core import NeuralODE
11
+ from networks.utils import flow_model_torch_wrapper
12
+ from utils import wasserstein_distance, plot_lidar
13
+ from branchsbm.ema import EMA
14
+
15
+ class BranchFlowNetTrainBase(pl.LightningModule):
16
+ def __init__(
17
+ self,
18
+ flow_matcher,
19
+ flow_nets,
20
+ skipped_time_points=None,
21
+ ot_sampler=None,
22
+ args=None,
23
+ ):
24
+ super().__init__()
25
+ self.args = args
26
+
27
+ self.flow_matcher = flow_matcher
28
+ self.flow_nets = flow_nets # list of flow networks for each branch
29
+ self.ot_sampler = ot_sampler
30
+ self.skipped_time_points = skipped_time_points
31
+
32
+ self.optimizer_name = args.flow_optimizer
33
+ self.lr = args.flow_lr
34
+ self.weight_decay = args.flow_weight_decay
35
+ self.whiten = args.whiten
36
+ self.working_dir = args.working_dir
37
+
38
+ #branching
39
+ self.branches = len(flow_nets)
40
+
41
+ def forward(self, t, xt, branch_idx):
42
+ # output velocity given branch_idx
43
+ return self.flow_nets[branch_idx](t, xt)
44
+
45
+ def _compute_loss(self, main_batch):
46
+
47
+ x0s = [main_batch["x0"][0]]
48
+ w0s = [main_batch["x0"][1]]
49
+
50
+ x1s_list = []
51
+ w1s_list = []
52
+
53
+ if self.branches > 1:
54
+ for i in range(self.branches):
55
+ x1s_list.append([main_batch[f"x1_{i+1}"][0]])
56
+ w1s_list.append([main_batch[f"x1_{i+1}"][1]])
57
+ else:
58
+ x1s_list.append([main_batch["x1"][0]])
59
+ w1s_list.append([main_batch["x1"][1]])
60
+
61
+ assert len(x1s_list) == self.branches, "Mismatch between x1s_list and expected branches"
62
+
63
+ loss = 0
64
+ for branch_idx in range(self.branches):
65
+ ts, xts, uts = self._process_flow(x0s, x1s_list[branch_idx], branch_idx)
66
+
67
+ t = torch.cat(ts)
68
+ xt = torch.cat(xts)
69
+ ut = torch.cat(uts)
70
+ vt = self(t[:, None], xt, branch_idx)
71
+
72
+ loss += mean_squared_error(vt, ut)
73
+
74
+ return loss
75
+
76
+ def _process_flow(self, x0s, x1s, branch_idx):
77
+ ts, xts, uts = [], [], []
78
+ t_start = self.timesteps[0]
79
+
80
+ for i, (x0, x1) in enumerate(zip(x0s, x1s)):
81
+
82
+ x0, x1 = torch.squeeze(x0), torch.squeeze(x1)
83
+
84
+ if self.ot_sampler is not None:
85
+ x0, x1 = self.ot_sampler.sample_plan(
86
+ x0,
87
+ x1,
88
+ replace=True,
89
+ )
90
+ if self.skipped_time_points and i + 1 >= self.skipped_time_points[0]:
91
+ t_start_next = self.timesteps[i + 2]
92
+ else:
93
+ t_start_next = self.timesteps[i + 1]
94
+
95
+ # edit to sample from correct flow matcher
96
+ t, xt, ut = self.flow_matcher.sample_location_and_conditional_flow(
97
+ x0, x1, t_start, t_start_next, branch_idx
98
+ )
99
+
100
+ ts.append(t)
101
+
102
+ xts.append(xt)
103
+ uts.append(ut)
104
+ t_start = t_start_next
105
+ return ts, xts, uts
106
+
107
+ def training_step(self, batch, batch_idx):
108
+ if self.args.data_type in ["scrna", "tahoe"]:
109
+ main_batch = batch[0]["train_samples"][0]
110
+ else:
111
+ main_batch = batch["train_samples"][0]
112
+
113
+ print("Main batch length")
114
+ print(len(main_batch["x0"]))
115
+ self.timesteps = torch.linspace(0.0, 1.0, len(main_batch["x0"])).tolist()
116
+ loss = self._compute_loss(main_batch)
117
+ if self.flow_matcher.alpha != 0:
118
+ self.log(
119
+ "FlowNet/mean_geopath_cfm",
120
+ (self.flow_matcher.geopath_net_output.abs().mean()),
121
+ on_step=False,
122
+ on_epoch=True,
123
+ prog_bar=True,
124
+ )
125
+
126
+ self.log(
127
+ "FlowNet/train_loss_cfm",
128
+ loss,
129
+ on_step=False,
130
+ on_epoch=True,
131
+ prog_bar=True,
132
+ logger=True,
133
+ )
134
+
135
+
136
+ return loss
137
+
138
+ def validation_step(self, batch, batch_idx):
139
+ if self.args.data_type in ["scrna", "tahoe"]:
140
+ main_batch = batch[0]["val_samples"][0]
141
+ else:
142
+ main_batch = batch["val_samples"][0]
143
+
144
+ self.timesteps = torch.linspace(0.0, 1.0, len(main_batch["x0"])).tolist()
145
+ val_loss = self._compute_loss(main_batch)
146
+ self.log(
147
+ "FlowNet/val_loss_cfm",
148
+ val_loss,
149
+ on_step=False,
150
+ on_epoch=True,
151
+ prog_bar=True,
152
+ logger=True,
153
+ )
154
+ return val_loss
155
+
156
+ def optimizer_step(self, *args, **kwargs):
157
+ super().optimizer_step(*args, **kwargs)
158
+
159
+ for net in self.flow_nets:
160
+ if isinstance(net, EMA):
161
+ net.update_ema()
162
+
163
+ def configure_optimizers(self):
164
+ if self.optimizer_name == "adamw":
165
+ optimizer = AdamW(
166
+ self.parameters(),
167
+ lr=self.lr,
168
+ weight_decay=self.weight_decay,
169
+ )
170
+ elif self.optimizer_name == "adam":
171
+ optimizer = torch.optim.Adam(
172
+ self.parameters(),
173
+ lr=self.lr,
174
+ )
175
+
176
+ return optimizer
177
+
178
+
179
+ class FlowNetTrainTrajectory(BranchFlowNetTrainBase):
180
+ def test_step(self, batch, batch_idx):
181
+ data_type = self.args.data_type
182
+ node = NeuralODE(
183
+ flow_model_torch_wrapper(self.flow_nets),
184
+ solver="euler",
185
+ sensitivity="adjoint",
186
+ atol=1e-5,
187
+ rtol=1e-5,
188
+ )
189
+
190
+ t_exclude = self.skipped_time_points[0] if self.skipped_time_points else None
191
+ if t_exclude is not None:
192
+ traj = node.trajectory(
193
+ batch[t_exclude - 1],
194
+ t_span=torch.linspace(
195
+ self.timesteps[t_exclude - 1], self.timesteps[t_exclude], 101
196
+ ),
197
+ )
198
+ X_mid_pred = traj[-1]
199
+ traj = node.trajectory(
200
+ batch[t_exclude - 1],
201
+ t_span=torch.linspace(
202
+ self.timesteps[t_exclude - 1],
203
+ self.timesteps[t_exclude + 1],
204
+ 101,
205
+ ),
206
+ )
207
+
208
+ EMD = wasserstein_distance(X_mid_pred, batch[t_exclude], p=1)
209
+ self.final_EMD = EMD
210
+
211
+ self.log("test_EMD", EMD, on_step=False, on_epoch=True, prog_bar=True)
212
+
213
+ class FlowNetTrainCell(BranchFlowNetTrainBase):
214
+ def test_step(self, batch, batch_idx):
215
+ x0 = batch[0]["test_samples"][0]["x0"][0] # [B, D]
216
+ dataset_points = batch[0]["test_samples"][0]["dataset"][0] # full dataset, [N, D]
217
+ t_span = torch.linspace(0, 1, 101)
218
+
219
+ all_trajs = []
220
+
221
+ for i, flow_net in enumerate(self.flow_nets):
222
+ node = NeuralODE(
223
+ flow_model_torch_wrapper(flow_net),
224
+ solver="euler",
225
+ sensitivity="adjoint",
226
+ )
227
+
228
+ with torch.no_grad():
229
+ traj = node.trajectory(x0, t_span).cpu() # [T, B, D]
230
+
231
+ if self.whiten:
232
+ traj_shape = traj.shape
233
+ traj = traj.reshape(-1, traj.shape[-1])
234
+ traj = self.trainer.datamodule.scaler.inverse_transform(
235
+ traj.cpu().detach().numpy()
236
+ ).reshape(traj_shape)
237
+ dataset_points = self.trainer.datamodule.scaler.inverse_transform(
238
+ dataset_points.cpu().detach().numpy()
239
+ )
240
+
241
+ traj = torch.tensor(traj)
242
+ traj = torch.transpose(traj, 0, 1) # [B, T, D]
243
+ all_trajs.append(traj)
244
+
245
+ dataset_2d = dataset_points[:, :2] if isinstance(dataset_points, torch.Tensor) else dataset_points[:, :2]
246
+
247
+ # ===== Plot all 2D trajectories together with dataset and start/end points =====
248
+ fig, ax = plt.subplots(figsize=(6, 5))
249
+ dataset_2d = dataset_2d.cpu().numpy()
250
+ ax.scatter(dataset_2d[:, 0], dataset_2d[:, 1], c="gray", s=1, alpha=0.5, label="Dataset", zorder=1)
251
+ for traj in all_trajs:
252
+ traj_2d = traj[..., :2] # [B, T, 2]
253
+ for i in range(traj_2d.shape[0]):
254
+ ax.plot(traj_2d[i, :, 0], traj_2d[i, :, 1], alpha=0.8, zorder=2)
255
+ ax.scatter(traj_2d[i, 0, 0], traj_2d[i, 0, 1], c='green', s=10, label="t=0" if i == 0 else "", zorder=3)
256
+ ax.scatter(traj_2d[i, -1, 0], traj_2d[i, -1, 1], c='red', s=10, label="t=1" if i == 0 else "", zorder=3)
257
+
258
+ ax.set_title("All Branch Trajectories (2D) with Dataset")
259
+ ax.set_xlabel("x")
260
+ ax.set_ylabel("y")
261
+ plt.axis("equal")
262
+ handles, labels = ax.get_legend_handles_labels()
263
+ if labels:
264
+ ax.legend()
265
+
266
+ save_path = f'./figures/{self.args.data_name}'
267
+
268
+ os.makedirs(save_path, exist_ok=True)
269
+ plt.savefig(f'{save_path}/{self.args.data_name}_all_branches.png', dpi=300)
270
+ plt.close()
271
+
272
+ # ===== Plot each 2D trajectory separately with dataset and endpoints =====
273
+ for i, traj in enumerate(all_trajs):
274
+ traj_2d = traj[..., :2]
275
+ fig, ax = plt.subplots(figsize=(6, 5))
276
+ ax.scatter(dataset_2d[:, 0], dataset_2d[:, 1], c="gray", s=1, alpha=0.5, label="Dataset", zorder=1)
277
+ for j in range(traj_2d.shape[0]):
278
+ ax.plot(traj_2d[j, :, 0], traj_2d[j, :, 1], alpha=0.9, zorder=2)
279
+ ax.scatter(traj_2d[j, 0, 0], traj_2d[j, 0, 1], c='green', s=12, label="t=0" if j == 0 else "", zorder=3)
280
+ ax.scatter(traj_2d[j, -1, 0], traj_2d[j, -1, 1], c='red', s=12, label="t=1" if j == 0 else "", zorder=3)
281
+
282
+ ax.set_title(f"Branch {i + 1} Trajectories (2D) with Dataset")
283
+ ax.set_xlabel("x")
284
+ ax.set_ylabel("y")
285
+ plt.axis("equal")
286
+ handles, labels = ax.get_legend_handles_labels()
287
+ if labels:
288
+ ax.legend()
289
+ plt.savefig(f'{save_path}/{self.args.data_name}_branch_{i + 1}.png', dpi=300)
290
+ plt.close()
291
+
292
+ class FlowNetTrainLidar(BranchFlowNetTrainBase):
293
+ def test_step(self, batch, batch_idx):
294
+ main_batch = batch["test_samples"][0]
295
+ metric_batch = batch["metric_samples"][0]
296
+
297
+ x0 = main_batch["x0"][0] # [B, D]
298
+ cloud_points = main_batch["dataset"][0] # full dataset, [N, D]
299
+ t_span = torch.linspace(0, 1, 101)
300
+
301
+ all_trajs = []
302
+
303
+ for i, flow_net in enumerate(self.flow_nets):
304
+ node = NeuralODE(
305
+ flow_model_torch_wrapper(flow_net),
306
+ solver="euler",
307
+ sensitivity="adjoint",
308
+ )
309
+
310
+ with torch.no_grad():
311
+ traj = node.trajectory(x0, t_span).cpu() # [T, B, D]
312
+
313
+ if self.whiten:
314
+ traj_shape = traj.shape
315
+ traj = traj.reshape(-1, 3)
316
+ traj = self.trainer.datamodule.scaler.inverse_transform(
317
+ traj.cpu().detach().numpy()
318
+ ).reshape(traj_shape)
319
+
320
+ traj = torch.tensor(traj)
321
+ traj = torch.transpose(traj, 0, 1) # [B, T, D]
322
+ all_trajs.append(traj)
323
+
324
+ # Inverse-transform the point cloud once
325
+ if self.whiten:
326
+ cloud_points = torch.tensor(
327
+ self.trainer.datamodule.scaler.inverse_transform(
328
+ cloud_points.cpu().detach().numpy()
329
+ )
330
+ )
331
+
332
+ # ===== Plot all trajectories together =====
333
+ fig = plt.figure(figsize=(6, 5))
334
+ ax = fig.add_subplot(111, projection="3d", computed_zorder=False)
335
+ ax.view_init(elev=30, azim=-115, roll=0)
336
+ for i, traj in enumerate(all_trajs):
337
+ plot_lidar(ax, cloud_points, xs=traj, branch_idx=i)
338
+ plt.savefig('./figures/lidar/lidar_all_branches.png', dpi=300)
339
+ plt.close()
340
+
341
+ # ===== Plot each trajectory separately =====
342
+ for i, traj in enumerate(all_trajs):
343
+ fig = plt.figure(figsize=(6, 5))
344
+ ax = fig.add_subplot(111, projection="3d", computed_zorder=False)
345
+ ax.view_init(elev=30, azim=-115, roll=0)
346
+ plot_lidar(ax, cloud_points, xs=traj, branch_idx=i)
347
+ plt.savefig(f'./figures/lidar/lidar_branch_{i + 1}.png', dpi=300)
348
+ plt.close()
branchsbm/branch_growth_net_train.py ADDED
@@ -0,0 +1,514 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ sys.path.append("./BranchSBM")
4
+ import torch
5
+ import wandb
6
+ import matplotlib.pyplot as plt
7
+ import pytorch_lightning as pl
8
+ from torch.optim import AdamW
9
+ from torchmetrics.functional import mean_squared_error
10
+ from torchdyn.core import NeuralODE
11
+ import numpy as np
12
+ import lpips
13
+ from networks.utils import flow_model_torch_wrapper
14
+ from utils import wasserstein_distance, plot_lidar
15
+ from branchsbm.ema import EMA
16
+ from torchdiffeq import odeint as odeint2
17
+ from losses.energy_loss import EnergySolver, ReconsLoss
18
+
19
+ class GrowthNetTrain(pl.LightningModule):
20
+ def __init__(
21
+ self,
22
+ flow_nets,
23
+ growth_nets,
24
+ skipped_time_points=None,
25
+ ot_sampler=None,
26
+ args=None,
27
+
28
+ state_cost=None,
29
+ data_manifold_metric=None,
30
+
31
+ joint = False
32
+ ):
33
+ super().__init__()
34
+ #self.save_hyperparameters()
35
+ self.flow_nets = flow_nets
36
+
37
+ if not joint:
38
+ for param in self.flow_nets.parameters():
39
+ param.requires_grad = False
40
+
41
+ self.growth_nets = growth_nets # list of growth networks for each branch
42
+
43
+ self.ot_sampler = ot_sampler
44
+ self.skipped_time_points = skipped_time_points
45
+
46
+ self.optimizer_name = args.growth_optimizer
47
+ self.lr = args.growth_lr
48
+ self.weight_decay = args.growth_weight_decay
49
+ self.whiten = args.whiten
50
+ self.working_dir = args.working_dir
51
+
52
+ self.args = args
53
+
54
+ #branching
55
+ self.state_cost = state_cost
56
+ self.data_manifold_metric = data_manifold_metric
57
+ self.branches = len(growth_nets)
58
+ self.metric_clusters = args.metric_clusters
59
+
60
+ self.recons_loss = ReconsLoss()
61
+
62
+ # loss weights
63
+ self.lambda_energy = args.lambda_energy
64
+ self.lambda_mass = args.lambda_mass
65
+ self.lambda_match = args.lambda_match
66
+ self.lambda_recons = args.lambda_recons
67
+
68
+ self.joint = joint
69
+
70
+ def forward(self, t, xt, branch_idx):
71
+ # output growth rate given branch_idx
72
+ return self.growth_nets[branch_idx](t, xt)
73
+
74
+ def _compute_loss(self, main_batch, metric_samples_batch=None, validation=False):
75
+ x0s = main_batch["x0"][0]
76
+ w0s = main_batch["x0"][1]
77
+ x1s_list = []
78
+ w1s_list = []
79
+
80
+ if self.branches > 1:
81
+ for i in range(self.branches):
82
+ x1s_list.append([main_batch[f"x1_{i+1}"][0]])
83
+ w1s_list.append([main_batch[f"x1_{i+1}"][1]])
84
+ else:
85
+ x1s_list.append([main_batch["x1"][0]])
86
+ w1s_list.append([main_batch["x1"][1]])
87
+
88
+ if self.args.manifold:
89
+ #changed
90
+ if self.metric_clusters == 4:
91
+ branch_sample_pairs = [
92
+ (metric_samples_batch[0], metric_samples_batch[1]), # x0 → x1_1 (branch 1)
93
+ (metric_samples_batch[0], metric_samples_batch[2]), # x0 → x1_2 (branch 2)
94
+ (metric_samples_batch[0], metric_samples_batch[3]),
95
+ ]
96
+ elif self.metric_clusters == 3:
97
+ branch_sample_pairs = [
98
+ (metric_samples_batch[0], metric_samples_batch[1]), # x0 → x1_1 (branch 1)
99
+ (metric_samples_batch[0], metric_samples_batch[2]), # x0 → x1_2 (branch 2)
100
+ ]
101
+ elif self.metric_clusters == 2 and self.branches == 2:
102
+ branch_sample_pairs = [
103
+ (metric_samples_batch[0], metric_samples_batch[1]), # x0 → x1_1 (branch 1)
104
+ (metric_samples_batch[0], metric_samples_batch[1]), # x0 → x1_2 (branch 2)
105
+ ]
106
+ else:
107
+ branch_sample_pairs = [
108
+ (metric_samples_batch[0], metric_samples_batch[1]), # x0 → x1_1 (branch 1)
109
+ ]
110
+
111
+ batch_size = x0s.shape[0]
112
+
113
+ assert len(x1s_list) == self.branches, "Mismatch between x1s_list and expected branches"
114
+
115
+ energy_loss = [0.] * self.branches
116
+ mass_loss = 0.
117
+ neg_weight_penalty = 0.
118
+ match_loss = [0.] * self.branches
119
+ recons_loss = [0.] * self.branches
120
+
121
+ dtype = x0s[0].dtype
122
+ #w0s = torch.zeros((batch_size, 1), dtype=dtype)
123
+ m0s = torch.zeros_like(w0s, dtype=dtype)
124
+ start_state = (x0s, w0s, m0s)
125
+
126
+ xt = [x0s.clone() for _ in range(self.branches)]
127
+ w0_branch = torch.zeros_like(w0s, dtype=dtype)
128
+ w0_branches = []
129
+ w0_branches.append(w0s)
130
+ for _ in range(self.branches - 1):
131
+ w0_branches.append(w0_branch)
132
+ #w0_branches = [w0_branch.clone() for _ in range(self.branches - 1)]
133
+ wt = w0_branches
134
+
135
+ mt = [m0s.clone() for _ in range(self.branches)]
136
+
137
+ # loop through timesteps
138
+ for s, t in zip(self.timesteps[:-1], self.timesteps[1:]):
139
+ time = torch.Tensor([s, t])
140
+
141
+ total_w_t = 0
142
+ # loop through branches
143
+ for i in range(self.branches):
144
+
145
+ if self.args.manifold:
146
+ start_samples, end_samples = branch_sample_pairs[i]
147
+ samples = torch.cat([start_samples, end_samples], dim=0)
148
+
149
+ # initialize weight and energy
150
+ start_state = (xt[i], wt[i], mt[i])
151
+
152
+ # loop over timesteps
153
+ xt_next, wt_next, mt_next = self.take_step(time, start_state, i, samples)
154
+
155
+ # placeholders for next state
156
+ xt_last = xt_next[-1]
157
+ wt_last = wt_next[-1]
158
+ mt_last = mt_next[-1]
159
+
160
+ total_w_t += wt_last
161
+
162
+ energy_loss[i] += (mt_last - mt[i])
163
+ neg_weight_penalty += torch.relu(-wt_last).sum()
164
+
165
+ # update branch state
166
+ xt[i] = xt_last.clone().detach()
167
+ wt[i] = wt_last.clone().detach()
168
+ mt[i] = mt_last.clone().detach()
169
+
170
+ # calculate mass loss from all branches
171
+ target = torch.ones_like(total_w_t)
172
+ mass_loss += mean_squared_error(total_w_t, target)
173
+
174
+ # calculate loss that matches final weights
175
+ for i in range(self.branches):
176
+ match_loss[i] = mean_squared_error(wt[i], w1s_list[i][0])
177
+ # compute reconstruction loss
178
+ recons_loss[i] = self.recons_loss(xt[i], x1s_list[i][0])
179
+
180
+ # average across times
181
+ mass_loss = mass_loss / len(self.timesteps)
182
+
183
+
184
+ # mean across branches
185
+ energy_loss = torch.mean(torch.stack(energy_loss))
186
+ match_loss = torch.mean(torch.stack(match_loss))
187
+ recons_loss = torch.mean(torch.stack(recons_loss))
188
+
189
+ loss = (self.lambda_energy * energy_loss) + (self.lambda_mass * (mass_loss + neg_weight_penalty)) + (self.lambda_match * match_loss) \
190
+ + (self.lambda_recons * recons_loss)
191
+
192
+ if self.joint:
193
+ if validation:
194
+ self.log("JointTrain/val_mass_loss", mass_loss, on_step=False, on_epoch=True, prog_bar=True)
195
+ self.log("JointTrain/val_neg_penalty_loss", neg_weight_penalty, on_step=False, on_epoch=True, prog_bar=True)
196
+ self.log("JointTrain/val_match_loss", match_loss, on_step=False, on_epoch=True, prog_bar=True)
197
+ self.log("JointTrain/val_energy_loss", energy_loss, on_step=False, on_epoch=True, prog_bar=True)
198
+ self.log("JointTrain/val_recons_loss", recons_loss, on_step=False, on_epoch=True, prog_bar=True)
199
+ else:
200
+ self.log("JointTrain/train_mass_loss", mass_loss, on_step=False, on_epoch=True, prog_bar=True)
201
+ self.log("JointTrain/train_neg_penalty_loss", neg_weight_penalty, on_step=False, on_epoch=True, prog_bar=True)
202
+ self.log("JointTrain/train_match_loss", match_loss, on_step=False, on_epoch=True, prog_bar=True)
203
+ self.log("JointTrain/train_energy_loss", energy_loss, on_step=False, on_epoch=True, prog_bar=True)
204
+ self.log("JointTrain/train_recons_loss", recons_loss, on_step=False, on_epoch=True, prog_bar=True)
205
+ else:
206
+ if validation:
207
+ self.log("GrowthNet/val_mass_loss", mass_loss, on_step=False, on_epoch=True, prog_bar=True)
208
+ self.log("GrowthNet/val_neg_penalty_loss", neg_weight_penalty, on_step=False, on_epoch=True, prog_bar=True)
209
+ self.log("GrowthNet/val_match_loss", match_loss, on_step=False, on_epoch=True, prog_bar=True)
210
+ self.log("GrowthNet/val_energy_loss", energy_loss, on_step=False, on_epoch=True, prog_bar=True)
211
+ self.log("GrowthNet/val_recons_loss", recons_loss, on_step=False, on_epoch=True, prog_bar=True)
212
+ else:
213
+ self.log("GrowthNet/train_mass_loss", mass_loss, on_step=False, on_epoch=True, prog_bar=True)
214
+ self.log("GrowthNet/train_neg_penalty_loss", neg_weight_penalty, on_step=False, on_epoch=True, prog_bar=True)
215
+ self.log("GrowthNet/train_match_loss", match_loss, on_step=False, on_epoch=True, prog_bar=True)
216
+ self.log("GrowthNet/train_energy_loss", energy_loss, on_step=False, on_epoch=True, prog_bar=True)
217
+ self.log("GrowthNet/train_recons_loss", recons_loss, on_step=False, on_epoch=True, prog_bar=True)
218
+
219
+ return loss
220
+
221
+ def take_step(self, t, start_state, branch_idx, samples=None):
222
+
223
+ flow_net = self.flow_nets[branch_idx]
224
+ growth_net = self.growth_nets[branch_idx]
225
+
226
+
227
+ x_t, w_t, m_t = odeint2(EnergySolver(flow_net, growth_net, self.state_cost, self.data_manifold_metric, samples), start_state, t, options=dict(step_size=0.1),method='euler')
228
+
229
+ return x_t, w_t, m_t
230
+
231
+ def training_step(self, batch, batch_idx):
232
+ if self.args.data_type in ["scrna", "tahoe"]:
233
+ main_batch = batch[0]["train_samples"][0]
234
+ metric_batch = batch[0]["metric_samples"][0]
235
+ else:
236
+ main_batch = batch["train_samples"][0]
237
+ metric_batch = batch["metric_samples"][0]
238
+
239
+ self.timesteps = torch.linspace(0.0, 1.0, len(main_batch["x0"])).tolist()
240
+ loss = self._compute_loss(main_batch, metric_batch, validation=False)
241
+
242
+ if self.joint:
243
+ self.log(
244
+ "JointTrain/train_loss",
245
+ loss,
246
+ on_step=False,
247
+ on_epoch=True,
248
+ prog_bar=True,
249
+ logger=True,
250
+ )
251
+ else:
252
+ self.log(
253
+ "GrowthNet/train_loss",
254
+ loss,
255
+ on_step=False,
256
+ on_epoch=True,
257
+ prog_bar=True,
258
+ logger=True,
259
+ )
260
+
261
+ return loss
262
+
263
+ def validation_step(self, batch, batch_idx):
264
+ if self.args.data_type in ["scrna", "tahoe"]:
265
+ main_batch = batch[0]["val_samples"][0]
266
+ metric_batch = batch[0]["metric_samples"][0]
267
+ else:
268
+ main_batch = batch["val_samples"][0]
269
+ metric_batch = batch["metric_samples"][0]
270
+
271
+ self.timesteps = torch.linspace(0.0, 1.0, len(main_batch["x0"])).tolist()
272
+ val_loss = self._compute_loss(main_batch, metric_batch, validation=True)
273
+
274
+ if self.joint:
275
+ self.log(
276
+ "JointTrain/val_loss",
277
+ val_loss,
278
+ on_step=False,
279
+ on_epoch=True,
280
+ prog_bar=True,
281
+ logger=True,
282
+ )
283
+ else:
284
+ self.log(
285
+ "GrowthNet/val_loss",
286
+ val_loss,
287
+ on_step=False,
288
+ on_epoch=True,
289
+ prog_bar=True,
290
+ logger=True,
291
+ )
292
+ return val_loss
293
+
294
+ def optimizer_step(self, *args, **kwargs):
295
+ super().optimizer_step(*args, **kwargs)
296
+ for net in self.growth_nets:
297
+ if isinstance(net, EMA):
298
+ net.update_ema()
299
+ if self.joint:
300
+ for net in self.flow_nets:
301
+ if isinstance(net, EMA):
302
+ net.update_ema()
303
+
304
+ def configure_optimizers(self):
305
+ params = []
306
+ for net in self.growth_nets:
307
+ params += list(net.parameters())
308
+
309
+ if self.joint:
310
+ for net in self.flow_nets:
311
+ params += list(net.parameters())
312
+
313
+ if self.optimizer_name == "adamw":
314
+ optimizer = AdamW(
315
+ params,
316
+ lr=self.lr,
317
+ weight_decay=self.weight_decay,
318
+ )
319
+ elif self.optimizer_name == "adam":
320
+ optimizer = torch.optim.Adam(
321
+ params,
322
+ lr=self.lr,
323
+ )
324
+
325
+ return optimizer
326
+
327
+ @torch.no_grad()
328
+ def _plot_mass_and_energy(self, main_batch, metric_samples_batch=None, save_dir="./figures"):
329
+ x0s = main_batch["x0"][0]
330
+ w0s = main_batch["x0"][1]
331
+
332
+ if self.args.manifold:
333
+ if self.metric_clusters == 4:
334
+ branch_sample_pairs = [
335
+ (metric_samples_batch[0], metric_samples_batch[1]), # x0 → x1_1 (branch 1)
336
+ (metric_samples_batch[0], metric_samples_batch[2]), # x0 → x1_2 (branch 2)
337
+ (metric_samples_batch[0], metric_samples_batch[3]),
338
+ ]
339
+ elif self.metric_clusters == 3:
340
+ branch_sample_pairs = [
341
+ (metric_samples_batch[0], metric_samples_batch[1]), # x0 → x1_1 (branch 1)
342
+ (metric_samples_batch[0], metric_samples_batch[2]), # x0 → x1_2 (branch 2)
343
+ ]
344
+ elif self.metric_clusters == 2 and self.branches == 2:
345
+ branch_sample_pairs = [
346
+ (metric_samples_batch[0], metric_samples_batch[1]), # x0 → x1_1 (branch 1)
347
+ (metric_samples_batch[0], metric_samples_batch[1]), # x0 → x1_2 (branch 2)
348
+ ]
349
+ else:
350
+ branch_sample_pairs = [
351
+ (metric_samples_batch[0], metric_samples_batch[1]), # x0 → x1_1 (branch 1)
352
+ ]
353
+
354
+ batch_size = x0s.shape[0]
355
+ dtype = x0s[0].dtype
356
+
357
+ m0s = torch.zeros_like(w0s, dtype=dtype)
358
+ xt = [x0s.clone() for _ in range(self.branches)]
359
+
360
+ w0_branch = torch.zeros_like(w0s, dtype=dtype)
361
+ w0_branches = []
362
+ w0_branches.append(w0s)
363
+ for _ in range(self.branches - 1):
364
+ w0_branches.append(w0_branch)
365
+
366
+ wt = w0_branches
367
+ mt = [m0s.clone() for _ in range(self.branches)]
368
+
369
+ time_points = []
370
+ mass_over_time = [[] for _ in range(self.branches)]
371
+ energy_over_time = [[] for _ in range(self.branches)]
372
+
373
+ t_span = torch.linspace(0, 1, 101)
374
+ for s, t in zip(t_span[:-1], t_span[1:]):
375
+ time_points.append(t.item())
376
+ time = torch.Tensor([s, t])
377
+
378
+ for i in range(self.branches):
379
+ if self.args.manifold:
380
+ start_samples, end_samples = branch_sample_pairs[i]
381
+ samples = torch.cat([start_samples, end_samples], dim=0)
382
+ else:
383
+ samples = None
384
+
385
+ start_state = (xt[i], wt[i], mt[i])
386
+ xt_next, wt_next, mt_next = self.take_step(time, start_state, i, samples)
387
+
388
+ xt[i] = xt_next[-1].clone().detach()
389
+ wt[i] = wt_next[-1].clone().detach()
390
+ mt[i] = mt_next[-1].clone().detach()
391
+
392
+ mass_over_time[i].append(wt[i].mean().item())
393
+ energy_over_time[i].append(mt[i].mean().item())
394
+
395
+ os.makedirs(os.path.join(save_dir, self.args.data_type), exist_ok=True)
396
+
397
+ # Use tab10 colormap to get visually distinct colors
398
+ if self.args.branches == 3:
399
+ branch_colors = ['#9793F8', '#50B2D7', '#D577FF'] # tuple of RGBs
400
+ else:
401
+ branch_colors = ['#50B2D7', '#D577FF'] # tuple of RGBs
402
+
403
+ # --- Plot Mass ---
404
+ plt.figure(figsize=(8, 5))
405
+ for i in range(self.branches):
406
+ color = branch_colors[i]
407
+ plt.plot(time_points, mass_over_time[i], color=color, linewidth=2.5, label=f"Mass Branch {i}")
408
+ plt.xlabel("Time")
409
+ plt.ylabel("Mass")
410
+ plt.title("Mass Evolution per Branch")
411
+ plt.legend()
412
+ plt.grid(True)
413
+ if self.joint:
414
+ mass_path = os.path.join(save_dir, f"{self.args.data_name}/{self.args.data_name}_joint_mass.png")
415
+ else:
416
+ mass_path = os.path.join(save_dir, f"{self.args.data_name}/{self.args.data_name}_growth_mass.png")
417
+ plt.savefig(mass_path, dpi=300, bbox_inches="tight")
418
+ plt.close()
419
+
420
+ # --- Plot Energy ---
421
+ plt.figure(figsize=(8, 5))
422
+ for i in range(self.branches):
423
+ color = branch_colors[i]
424
+ plt.plot(time_points, energy_over_time[i], color=color, linewidth=2.5, label=f"Energy Branch {i}")
425
+ plt.xlabel("Time")
426
+ plt.ylabel("Energy")
427
+ plt.title("Energy Evolution per Branch")
428
+ plt.legend()
429
+ plt.grid(True)
430
+ if self.joint:
431
+ energy_path = os.path.join(save_dir, f"{self.args.data_name}/{self.args.data_name}_joint_energy.png")
432
+ else:
433
+ energy_path = os.path.join(save_dir, f"{self.args.data_name}/{self.args.data_name}_growth_energy.png")
434
+ plt.savefig(energy_path, dpi=300, bbox_inches="tight")
435
+ plt.close()
436
+
437
+
438
+ class GrowthNetTrainLidar(GrowthNetTrain):
439
+ def test_step(self, batch, batch_idx):
440
+ main_batch = batch["test_samples"][0]
441
+ metric_batch = batch["metric_samples"][0]
442
+
443
+ self._plot_mass_and_energy(main_batch, metric_batch)
444
+
445
+ x0 = main_batch["x0"][0] # [B, D]
446
+ cloud_points = main_batch["dataset"][0] # full dataset, [N, D]
447
+ t_span = torch.linspace(0, 1, 101)
448
+
449
+
450
+ all_trajs = []
451
+
452
+ for i, flow_net in enumerate(self.flow_nets):
453
+ node = NeuralODE(
454
+ flow_model_torch_wrapper(flow_net),
455
+ solver="euler",
456
+ sensitivity="adjoint",
457
+ )
458
+
459
+ with torch.no_grad():
460
+ traj = node.trajectory(x0, t_span).cpu() # [T, B, D]
461
+
462
+ if self.whiten:
463
+ traj_shape = traj.shape
464
+ traj = traj.reshape(-1, 3)
465
+ traj = self.trainer.datamodule.scaler.inverse_transform(
466
+ traj.cpu().detach().numpy()
467
+ ).reshape(traj_shape)
468
+
469
+ traj = torch.tensor(traj)
470
+ traj = torch.transpose(traj, 0, 1) # [B, T, D]
471
+ all_trajs.append(traj)
472
+
473
+ # Inverse-transform the point cloud once
474
+ if self.whiten:
475
+ cloud_points = torch.tensor(
476
+ self.trainer.datamodule.scaler.inverse_transform(
477
+ cloud_points.cpu().detach().numpy()
478
+ )
479
+ )
480
+
481
+ # ===== Plot all trajectories together =====
482
+ fig = plt.figure(figsize=(6, 5))
483
+ ax = fig.add_subplot(111, projection="3d", computed_zorder=False)
484
+ ax.view_init(elev=30, azim=-115, roll=0)
485
+ for i, traj in enumerate(all_trajs):
486
+ plot_lidar(ax, cloud_points, xs=traj, branch_idx=i)
487
+ if self.joint:
488
+ plt.savefig('./figures/lidar/joint_lidar_all_branches.png', dpi=300)
489
+ else:
490
+ plt.savefig('./figures/lidar/growth_lidar_all_branches.png', dpi=300)
491
+ plt.close()
492
+
493
+ # ===== Plot each trajectory separately =====
494
+ for i, traj in enumerate(all_trajs):
495
+ fig = plt.figure(figsize=(6, 5))
496
+ ax = fig.add_subplot(111, projection="3d", computed_zorder=False)
497
+ ax.view_init(elev=30, azim=-115, roll=0)
498
+ plot_lidar(ax, cloud_points, xs=traj, branch_idx=i)
499
+ if self.joint:
500
+ plt.savefig(f'./figures/lidar/joint_lidar_branch_{i + 1}.png', dpi=300)
501
+ else:
502
+ plt.savefig(f'./figures/lidar/growth_lidar_branch_{i + 1}.png', dpi=300)
503
+ plt.close()
504
+
505
+ class GrowthNetTrainCell(GrowthNetTrain):
506
+ def test_step(self, batch, batch_idx):
507
+ if self.args.data_type in ["scrna", "tahoe"]:
508
+ main_batch = batch[0]["test_samples"][0]
509
+ metric_batch = batch[0]["metric_samples"][0]
510
+ else:
511
+ main_batch = batch["test_samples"][0]
512
+ metric_batch = batch["metric_samples"][0]
513
+
514
+ self._plot_mass_and_energy(main_batch, metric_batch)
branchsbm/branch_interpolant_train.py ADDED
@@ -0,0 +1,398 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ sys.path.append("./BranchSBM")
3
+ import torch
4
+ import pytorch_lightning as pl
5
+ from branchsbm.ema import EMA
6
+ import itertools
7
+ from utils import wasserstein_distance, plot_lidar
8
+ import matplotlib.pyplot as plt
9
+
10
+ class BranchInterpolantTrain(pl.LightningModule):
11
+ def __init__(
12
+ self,
13
+ flow_matcher,
14
+ args,
15
+ skipped_time_points: list = None,
16
+ ot_sampler=None,
17
+
18
+ state_cost=None,
19
+ data_manifold_metric=None,
20
+ ):
21
+ super().__init__()
22
+ self.save_hyperparameters()
23
+ self.args = args
24
+
25
+ self.flow_matcher = flow_matcher
26
+
27
+ # list of geopath nets
28
+ self.geopath_nets = flow_matcher.geopath_nets
29
+ self.branches = len(self.geopath_nets)
30
+ self.metric_clusters = args.metric_clusters
31
+
32
+ self.ot_sampler = ot_sampler
33
+ self.skipped_time_points = skipped_time_points if skipped_time_points else []
34
+ self.optimizer_name = args.geopath_optimizer
35
+ self.lr = args.geopath_lr
36
+ self.weight_decay = args.geopath_weight_decay
37
+ self.args = args
38
+ self.multiply_validation = 4
39
+
40
+ self.first_loss = None
41
+ self.timesteps = None
42
+ self.computing_reference_loss = False
43
+
44
+ # updates
45
+ self.state_cost = state_cost
46
+ self.data_manifold_metric = data_manifold_metric
47
+ self.whiten = args.whiten
48
+
49
+ def forward(self, x0, x1, t, branch_idx):
50
+ # return specific branch interpolant
51
+ return self.geopath_nets[branch_idx](x0, x1, t)
52
+
53
+ def on_train_start(self):
54
+ self.first_loss = self.compute_initial_loss()
55
+ print("first loss")
56
+ print(self.first_loss)
57
+
58
+ # to edit
59
+ def compute_initial_loss(self):
60
+ # Set all GeoPath networks to eval mode
61
+ for net in self.geopath_nets:
62
+ net.train(mode=False)
63
+
64
+ total_loss = 0
65
+ total_count = 0
66
+ with torch.enable_grad():
67
+ self.t_val = []
68
+ for i in range(
69
+ self.trainer.datamodule.num_timesteps - len(self.skipped_time_points)
70
+ ):
71
+ self.t_val.append(
72
+ torch.rand(
73
+ self.trainer.datamodule.batch_size * self.multiply_validation,
74
+ requires_grad=True,
75
+ )
76
+ )
77
+ self.computing_reference_loss = True
78
+ with torch.no_grad():
79
+ old_alpha = self.flow_matcher.alpha
80
+ self.flow_matcher.alpha = 0
81
+ for batch in self.trainer.datamodule.train_dataloader():
82
+
83
+ self.timesteps = torch.linspace(
84
+ 0.0, 1.0, len(batch[0]["train_samples"][0])
85
+ )
86
+
87
+ loss = self._compute_loss(
88
+ batch[0]["train_samples"][0],
89
+ batch[0]["metric_samples"][0],
90
+ )
91
+ print("initial loss")
92
+ print(loss)
93
+ total_loss += loss.item()
94
+ total_count += 1
95
+ self.flow_matcher.alpha = old_alpha
96
+
97
+ self.computing_reference_loss = False
98
+
99
+ # Set all GeoPath networks back to training mode
100
+ for net in self.geopath_nets:
101
+ net.train(mode=True)
102
+ return total_loss / total_count if total_count > 0 else 1.0
103
+
104
+ def _compute_loss(self, main_batch, metric_samples_batch=None):
105
+
106
+ x0s = [main_batch["x0"][0]]
107
+ w0s = [main_batch["x0"][1]]
108
+
109
+ x1s_list = []
110
+ w1s_list = []
111
+
112
+ if self.branches > 1:
113
+ for i in range(self.branches):
114
+ x1s_list.append([main_batch[f"x1_{i+1}"][0]])
115
+ w1s_list.append([main_batch[f"x1_{i+1}"][1]])
116
+ else:
117
+ x1s_list.append([main_batch["x1"][0]])
118
+ w1s_list.append([main_batch["x1"][1]])
119
+
120
+ if self.args.manifold:
121
+ #changed
122
+ if self.metric_clusters == 4:
123
+ branch_sample_pairs = [
124
+ (metric_samples_batch[0], metric_samples_batch[1]), # x0 → x1_1 (branch 1)
125
+ (metric_samples_batch[0], metric_samples_batch[2]), # x0 → x1_2 (branch 2)
126
+ (metric_samples_batch[0], metric_samples_batch[3]),
127
+ ]
128
+ elif self.metric_clusters == 3:
129
+ branch_sample_pairs = [
130
+ (metric_samples_batch[0], metric_samples_batch[1]), # x0 → x1_1 (branch 1)
131
+ (metric_samples_batch[0], metric_samples_batch[2]), # x0 → x1_2 (branch 2)
132
+ ]
133
+ elif self.metric_clusters == 2 and self.branches == 2:
134
+ branch_sample_pairs = [
135
+ (metric_samples_batch[0], metric_samples_batch[1]), # x0 → x1_1 (branch 1)
136
+ (metric_samples_batch[0], metric_samples_batch[1]), # x0 → x1_2 (branch 2)
137
+ ]
138
+ else:
139
+ branch_sample_pairs = [
140
+ (metric_samples_batch[0], metric_samples_batch[1]), # x0 → x1_1 (branch 1)
141
+ ]
142
+ """samples0, samples1, samples2 = (
143
+ metric_samples_batch[0],
144
+ metric_samples_batch[1],
145
+ metric_samples_batch[2]
146
+ )"""
147
+
148
+ assert len(x1s_list) == self.branches, "Mismatch between x1s_list and expected branches"
149
+
150
+ # compute sum of velocities for each branch
151
+ loss = 0
152
+ velocities = []
153
+ for branch_idx in range(self.branches):
154
+
155
+ ts, xts, uts = self._process_flow(x0s, x1s_list[branch_idx], branch_idx)
156
+
157
+ for i in range(len(ts)):
158
+ # calculate kinetic and potential energy of the predicted interpolant
159
+ if self.args.manifold:
160
+ start_samples, end_samples = branch_sample_pairs[branch_idx]
161
+
162
+ samples = torch.cat([start_samples, end_samples], dim=0)
163
+ #print("metric sample shape")
164
+ #print(samples.shape)
165
+ vel, _, _ = self.data_manifold_metric.calculate_velocity(
166
+ xts[i], uts[i], samples, i
167
+ )
168
+ else:
169
+ vel = torch.sqrt((uts[i]**2).sum(dim =-1) + self.state_cost(xts[i]))
170
+ #vel = (uts[i]**2).sum(dim =-1)
171
+
172
+ velocities.append(vel)
173
+
174
+ loss = torch.mean(torch.cat(velocities) ** 2)
175
+
176
+ self.log(
177
+ "BranchPathNet/mean_velocity_geopath",
178
+ loss,
179
+ on_step=False,
180
+ on_epoch=True,
181
+ prog_bar=True,
182
+ )
183
+
184
+ return loss
185
+
186
+ def _process_flow(self, x0s, x1s, branch_idx):
187
+ ts, xts, uts = [], [], []
188
+ t_start = self.timesteps[0]
189
+ i_start = 0
190
+
191
+ for i, (x0, x1) in enumerate(zip(x0s, x1s)):
192
+ x0, x1 = torch.squeeze(x0), torch.squeeze(x1)
193
+ if self.trainer.validating or self.computing_reference_loss:
194
+ repeat_tuple = (self.multiply_validation, 1) + (1,) * (
195
+ len(x0.shape) - 2
196
+ )
197
+ x0 = x0.repeat(repeat_tuple)
198
+ x1 = x1.repeat(repeat_tuple)
199
+
200
+ if self.ot_sampler is not None:
201
+ x0, x1 = self.ot_sampler.sample_plan(
202
+ x0,
203
+ x1,
204
+ replace=True,
205
+ )
206
+ if self.skipped_time_points and i + 1 >= self.skipped_time_points[0]:
207
+ t_start_next = self.timesteps[i + 2]
208
+ else:
209
+ t_start_next = self.timesteps[i + 1]
210
+
211
+ t = None
212
+ if self.trainer.validating or self.computing_reference_loss:
213
+ t = self.t_val[i]
214
+
215
+ t, xt, ut = self.flow_matcher.sample_location_and_conditional_flow(
216
+ x0, x1, t_start, t_start_next, branch_idx, training_geopath_net=True, t=t
217
+ )
218
+ ts.append(t)
219
+ xts.append(xt)
220
+ uts.append(ut)
221
+ t_start = t_start_next
222
+
223
+ return ts, xts, uts
224
+
225
+ def training_step(self, batch, batch_idx):
226
+ if self.args.data_type in ["scrna", "tahoe"]:
227
+ main_batch = batch[0]["train_samples"][0]
228
+ metric_batch = batch[0]["metric_samples"][0]
229
+ else:
230
+ main_batch = batch["train_samples"][0]
231
+ metric_batch = batch["metric_samples"][0]
232
+
233
+ tangential_velocity_loss = self._compute_loss(main_batch, metric_batch)
234
+
235
+ if self.first_loss:
236
+ tangential_velocity_loss = tangential_velocity_loss / self.first_loss
237
+
238
+ self.log(
239
+ "BranchPathNet/mean_geopath_geopath",
240
+ (self.flow_matcher.geopath_net_output.abs().mean()),
241
+ on_step=False,
242
+ on_epoch=True,
243
+ prog_bar=True,
244
+ )
245
+
246
+ self.log(
247
+ "BranchPathNet/train_loss_geopath",
248
+ tangential_velocity_loss,
249
+ on_step=True,
250
+ on_epoch=True,
251
+ prog_bar=True,
252
+ logger=True,
253
+ )
254
+
255
+ return tangential_velocity_loss
256
+
257
+ def validation_step(self, batch, batch_idx):
258
+ if self.args.data_type in ["scrna", "tahoe"]:
259
+ main_batch = batch[0]["val_samples"][0]
260
+ metric_batch = batch[0]["metric_samples"][0]
261
+ else:
262
+ main_batch = batch["val_samples"][0]
263
+ metric_batch = batch["metric_samples"][0]
264
+
265
+ tangential_velocity_loss = self._compute_loss(main_batch, metric_batch)
266
+ if self.first_loss:
267
+ tangential_velocity_loss = tangential_velocity_loss / self.first_loss
268
+
269
+ self.log(
270
+ "BranchPathNet/val_loss_geopath",
271
+ tangential_velocity_loss,
272
+ on_step=False,
273
+ on_epoch=True,
274
+ prog_bar=True,
275
+ logger=True,
276
+ )
277
+ return tangential_velocity_loss
278
+
279
+
280
+ def test_step(self, batch, batch_idx):
281
+ main_batch = batch["test_samples"][0]
282
+ metric_batch = batch["metric_samples"][0]
283
+
284
+ x0 = main_batch["x0"][0] # [B, D]
285
+ cloud_points = main_batch["dataset"][0] # full dataset, [N, D]
286
+
287
+ x0 = x0.to(self.device)
288
+ cloud_points = cloud_points.to(self.device)
289
+
290
+ t_vals = [0.25, 0.5, 0.75]
291
+ t_labels = ["t=1/4", "t=1/2", "t=3/4"]
292
+
293
+ colors = {
294
+ "x0": "#4D176C",
295
+ "t=1/4": "#5C3B9D",
296
+ "t=1/2": "#6172B9",
297
+ "t=3/4": "#AC4E51",
298
+ "x1": "#771F4F",
299
+ }
300
+
301
+ # Unwhiten cloud points if needed
302
+ if self.whiten:
303
+ cloud_points = torch.tensor(
304
+ self.trainer.datamodule.scaler.inverse_transform(cloud_points.cpu().numpy())
305
+ )
306
+
307
+ for i in range(self.branches):
308
+ geopath = self.geopath_nets[i]
309
+ x1_key = f"x1_{i + 1}"
310
+ if x1_key not in main_batch:
311
+ print(f"Skipping branch {i + 1}: no final distribution {x1_key}")
312
+ continue
313
+
314
+ x1 = main_batch[x1_key][0].to(self.device)
315
+ print(x1.shape)
316
+ print(x0.shape)
317
+ interpolated_points = []
318
+ with torch.no_grad():
319
+ for t_scalar in t_vals:
320
+ t_tensor = torch.full((x0.shape[0], 1), t_scalar, device=self.device) # [B, 1]
321
+ xt = geopath(x0, x1, t_tensor).cpu() # [B, D]
322
+ if self.whiten:
323
+ xt = torch.tensor(
324
+ self.trainer.datamodule.scaler.inverse_transform(xt.numpy())
325
+ )
326
+ interpolated_points.append(xt)
327
+
328
+ if self.whiten:
329
+ x0_plot = torch.tensor(
330
+ self.trainer.datamodule.scaler.inverse_transform(x0.cpu().numpy())
331
+ )
332
+ x1_plot = torch.tensor(
333
+ self.trainer.datamodule.scaler.inverse_transform(x1.cpu().numpy())
334
+ )
335
+ else:
336
+ x0_plot = x0.cpu()
337
+ x1_plot = x1.cpu()
338
+
339
+ # Plot
340
+ fig = plt.figure(figsize=(6, 5))
341
+ ax = fig.add_subplot(111, projection="3d", computed_zorder=False)
342
+ ax.view_init(elev=30, azim=-115, roll=0)
343
+ plot_lidar(ax, cloud_points)
344
+
345
+ # Initial x₀
346
+ ax.scatter(
347
+ x0_plot[:, 0], x0_plot[:, 1], x0_plot[:, 2],
348
+ s=15, alpha=1.0, color=colors["x0"], label="x₀", depthshade=True,
349
+ edgecolors="white",
350
+ linewidths=0.3
351
+ )
352
+
353
+ # Interpolated points
354
+ for xt, t_label in zip(interpolated_points, t_labels):
355
+ ax.scatter(
356
+ xt[:, 0], xt[:, 1], xt[:, 2],
357
+ s=15, alpha=1.0, color=colors[t_label], label=t_label, depthshade=True,
358
+ edgecolors="white",
359
+ linewidths=0.3
360
+ )
361
+
362
+ # Final x₁
363
+ ax.scatter(
364
+ x1_plot[:, 0], x1_plot[:, 1], x1_plot[:, 2],
365
+ s=15, alpha=1.0, color=colors["x1"], label="x₁", depthshade=True,
366
+ edgecolors="white",
367
+ linewidths=0.3
368
+ )
369
+
370
+ ax.legend()
371
+ save_path = f"/raid/st512/branchsbm/figures/{self.args.data_type}/lidar_geopath_branch_{i+1}.png"
372
+ plt.savefig(save_path, dpi=300)
373
+ plt.close()
374
+
375
+ def optimizer_step(self, *args, **kwargs):
376
+ super().optimizer_step(*args, **kwargs)
377
+ if isinstance(self.geopath_nets, EMA):
378
+ self.geopath_nets.update_ema()
379
+
380
+ def configure_optimizers(self):
381
+ if self.optimizer_name == "adam":
382
+ """optimizer = torch.optim.Adam(
383
+ self.geopath_nets.parameters(),
384
+ lr=self.lr,
385
+ )"""
386
+ optimizer = torch.optim.Adam(
387
+ itertools.chain(*[net.parameters() for net in self.geopath_nets]), lr=self.lr
388
+ )
389
+ elif self.optimizer_name == "adamw":
390
+ """optimizer = torch.optim.AdamW(
391
+ self.geopath_nets.parameters(),
392
+ lr=self.lr,
393
+ weight_decay=self.weight_decay,
394
+ )"""
395
+ optimizer = torch.optim.AdamW(
396
+ itertools.chain(*[net.parameters() for net in self.geopath_nets]), lr=self.lr
397
+ )
398
+ return optimizer
branchsbm/branchsbm.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ sys.path.append("./BranchSBM")
3
+ import torch
4
+ from torchcfm.conditional_flow_matching import ConditionalFlowMatcher, pad_t_like_x
5
+ import torch.nn as nn
6
+
7
+ class BranchSBM(ConditionalFlowMatcher):
8
+ def __init__(
9
+ self, geopath_nets: nn.ModuleList = None, alpha: float = 1.0, *args, **kwargs
10
+ ):
11
+ super().__init__(*args, **kwargs)
12
+ self.alpha = alpha
13
+ self.geopath_nets = geopath_nets
14
+ if self.alpha != 0:
15
+ assert (
16
+ geopath_nets is not None
17
+ ), "GeoPath model must be provided if alpha != 0"
18
+
19
+ self.branches = len(geopath_nets)
20
+
21
+ def gamma(self, t, t_min, t_max):
22
+ return (
23
+ 1.0
24
+ - ((t - t_min) / (t_max - t_min)) ** 2
25
+ - ((t_max - t) / (t_max - t_min)) ** 2
26
+ )
27
+
28
+ def d_gamma(self, t, t_min, t_max):
29
+ return 2 * (-2 * t + t_max + t_min) / (t_max - t_min) ** 2
30
+
31
+ def compute_mu_t(self, x0, x1, t, t_min, t_max, branch_idx):
32
+ assert branch_idx < self.branches, "Index out of bounds"
33
+
34
+ with torch.enable_grad():
35
+ t = pad_t_like_x(t, x0)
36
+ if self.alpha == 0:
37
+ return (t_max - t) / (t_max - t_min) * x0 + (t - t_min) / (
38
+ t_max - t_min
39
+ ) * x1
40
+
41
+ # compute value for specific branch
42
+ self.geopath_net_output = self.geopath_nets[branch_idx](x0, x1, t)
43
+ if self.geopath_nets[branch_idx].time_geopath:
44
+ self.doutput_dt = torch.autograd.grad(
45
+ self.geopath_net_output,
46
+ t,
47
+ grad_outputs=torch.ones_like(self.geopath_net_output),
48
+ create_graph=False,
49
+ retain_graph=True,
50
+ )[0]
51
+ return (
52
+ (t_max - t) / (t_max - t_min) * x0
53
+ + (t - t_min) / (t_max - t_min) * x1
54
+ + self.gamma(t, t_min, t_max) * self.geopath_net_output
55
+ )
56
+
57
+ def sample_xt(self, x0, x1, t, epsilon, t_min, t_max, branch_idx):
58
+ assert branch_idx < self.branches, "Index out of bounds"
59
+ mu_t = self.compute_mu_t(x0, x1, t, t_min, t_max, branch_idx)
60
+ sigma_t = self.compute_sigma_t(t)
61
+ sigma_t = pad_t_like_x(sigma_t, x0)
62
+ return mu_t + sigma_t * epsilon
63
+
64
+ def sample_location_and_conditional_flow(
65
+ self,
66
+ x0,
67
+ x1,
68
+ t_min,
69
+ t_max,
70
+ branch_idx,
71
+ training_geopath_net=False,
72
+ midpoint_only=False,
73
+ t=None,
74
+ ):
75
+
76
+ self.training_geopath_net = training_geopath_net
77
+ with torch.enable_grad():
78
+ if t is None:
79
+ t = torch.rand(x0.shape[0], requires_grad=True)
80
+ t = t.type_as(x0)
81
+ t = t * (t_max - t_min) + t_min
82
+ if midpoint_only:
83
+ t = (t_max + t_min) / 2 * torch.ones_like(t).type_as(x0)
84
+
85
+ assert len(t) == x0.shape[0], "t has to have batch size dimension"
86
+
87
+ eps = self.sample_noise_like(x0)
88
+
89
+ # compute xt and ut for branch_idx
90
+ xt = self.sample_xt(x0, x1, t, eps, t_min, t_max, branch_idx)
91
+ ut = self.compute_conditional_flow(x0, x1, t, xt, t_min, t_max, branch_idx)
92
+
93
+ return t, xt, ut
94
+
95
+ def compute_conditional_flow(self, x0, x1, t, xt, t_min, t_max, branch_idx):
96
+ del xt
97
+ t = pad_t_like_x(t, x0)
98
+ if self.alpha == 0:
99
+ return (x1 - x0) / (t_max - t_min)
100
+
101
+ return (
102
+ (x1 - x0) / (t_max - t_min)
103
+ + self.d_gamma(t, t_min, t_max) * self.geopath_net_output
104
+ + (
105
+ self.gamma(t, t_min, t_max) * self.doutput_dt
106
+ if self.geopath_nets[branch_idx].time_geopath
107
+ else 0
108
+ )
109
+ )
branchsbm/ema.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ class EMA(torch.nn.Module):
4
+ def __init__(self, model: torch.nn.Module, decay: float = 0.999):
5
+ super().__init__()
6
+ self.model = model
7
+ self.decay = decay
8
+ if hasattr(self.model, "time_geopath"):
9
+ self.time_geopath = self.model.time_geopath
10
+
11
+ # Put this in a buffer so that it gets included in the state dict
12
+ self.register_buffer("num_updates", torch.tensor(0))
13
+
14
+ self.shadow_params = torch.nn.ParameterList(
15
+ [
16
+ torch.nn.Parameter(p.clone().detach(), requires_grad=False)
17
+ for p in model.parameters()
18
+ if p.requires_grad
19
+ ]
20
+ )
21
+ self.backup_params = []
22
+
23
+ def train(self, mode: bool):
24
+ if self.training and mode == False:
25
+ # Switching from train mode to eval mode. Backup the model parameters and
26
+ # overwrite with shadow params
27
+ self.backup()
28
+ self.copy_to_model()
29
+ elif not self.training and mode == True:
30
+ # Switching from eval to train mode. Restore the `backup_params`
31
+ self.restore_to_model()
32
+
33
+ super().train(mode)
34
+
35
+ def update_ema(self):
36
+ self.num_updates += 1
37
+ num_updates = self.num_updates.item()
38
+ decay = min(self.decay, (1 + num_updates) / (10 + num_updates))
39
+ with torch.no_grad():
40
+ params = [p for p in self.model.parameters() if p.requires_grad]
41
+ for shadow, param in zip(self.shadow_params, params):
42
+ shadow.sub_((1 - decay) * (shadow - param))
43
+
44
+ def forward(self, *args, **kwargs):
45
+ return self.model(*args, **kwargs)
46
+
47
+ def copy_to_model(self):
48
+ # copy the shadow (ema) parameters to the model
49
+ params = [p for p in self.model.parameters() if p.requires_grad]
50
+ for shaddow, param in zip(self.shadow_params, params):
51
+ param.data.copy_(shaddow.data)
52
+
53
+ def backup(self):
54
+ # Backup the current model parameters
55
+ if len(self.backup_params) > 0:
56
+ for p, b in zip(self.model.parameters(), self.backup_params):
57
+ b.data.copy_(p.data)
58
+ else:
59
+ self.backup_params = [param.clone() for param in self.model.parameters()]
60
+
61
+ def restore_to_model(self):
62
+ # Restores the backed up parameters to the model.
63
+ for param, backup in zip(self.model.parameters(), self.backup_params):
64
+ param.data.copy_(backup.data)
configs/.DS_Store ADDED
Binary file (6.15 kB). View file
 
configs/experiment/cell_single_branch.yaml ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ data_type: "scrna"
2
+ data_name: "mouse"
3
+ dim: 2
4
+ whiten: false
5
+ t_exclude: []
6
+ velocity_metric: "land"
7
+ gammas: [0.125]
8
+ rho: 0.001
9
+ branchsbm: true
10
+ seeds: [42]
11
+ patience_geopath: 50
12
+ time_geopath: true
configs/experiment/clonidine_100D.yaml ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ data_type: "tahoe"
2
+ data_name: "clonidine100D"
3
+ accelerator: "gpu"
4
+ hidden_dims_geopath: [1024, 1024, 1024]
5
+ hidden_dims_flow: [1024, 1024, 1024]
6
+ hidden_dims_growth: [1024, 1024, 1024]
7
+ dim: 100
8
+ t_exclude: []
9
+ time_geopath: true
10
+ whiten: false
11
+ velocity_metric: "rbf"
12
+ metric_patience: 25
13
+ patience: 25
14
+ n_centers: 300
15
+ kappa: 2
16
+ rho: -2.75
17
+ alpha_metric: 1
18
+ metric_epochs: 200
19
+ branchsbm: true
20
+ seeds: [42]
21
+ branches: 2
22
+ metric_clusters: 3
configs/experiment/clonidine_150D.yaml ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ data_type: "tahoe"
2
+ data_name: "clonidine150D"
3
+ accelerator: "gpu"
4
+ hidden_dims_geopath: [1024, 1024, 1024]
5
+ hidden_dims_flow: [1024, 1024, 1024]
6
+ hidden_dims_growth: [1024, 1024, 1024]
7
+ dim: 150
8
+ t_exclude: []
9
+ time_geopath: true
10
+ whiten: false
11
+ velocity_metric: "rbf"
12
+ metric_patience: 25
13
+ patience: 25
14
+ n_centers: 300
15
+ kappa: 3
16
+ rho: -2.75
17
+ alpha_metric: 1
18
+ metric_epochs: 400
19
+ branchsbm: true
20
+ seeds: [42]
21
+ branches: 2
22
+ metric_clusters: 3
configs/experiment/clonidine_50D.yaml ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ data_type: "tahoe"
2
+ data_name: "clonidine50D"
3
+ accelerator: "gpu"
4
+ hidden_dims_geopath: [1024, 1024, 1024]
5
+ hidden_dims_flow: [1024, 1024, 1024]
6
+ hidden_dims_growth: [1024, 1024, 1024]
7
+ dim: 50
8
+ t_exclude: []
9
+ time_geopath: true
10
+ whiten: false
11
+ velocity_metric: "rbf"
12
+ metric_patience: 25
13
+ patience: 25
14
+ n_centers: 150
15
+ kappa: 1.5
16
+ rho: -2.75
17
+ alpha_metric: 1
18
+ metric_epochs: 200
19
+ branchsbm: true
20
+ seeds: [42]
21
+ branches: 2
22
+ metric_clusters: 3
configs/experiment/clonidine_50Dsingle.yaml ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ data_type: "tahoe"
2
+ data_name: "clonidine50Dsingle"
3
+ accelerator: "gpu"
4
+ hidden_dims_geopath: [1024, 1024, 1024]
5
+ hidden_dims_flow: [1024, 1024, 1024]
6
+ hidden_dims_growth: [1024, 1024, 1024]
7
+ dim: 50
8
+ t_exclude: []
9
+ time_geopath: true
10
+ whiten: false
11
+ velocity_metric: "rbf"
12
+ metric_patience: 25
13
+ patience: 25
14
+ n_centers: 150
15
+ kappa: 1.5
16
+ rho: -2.75
17
+ alpha_metric: 1
18
+ metric_epochs: 200
19
+ branchsbm: true
20
+ seeds: [42]
21
+ branches: 1
22
+ metric_clusters: 2
configs/experiment/lidar.yaml ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ data_type: "lidar"
2
+ data_name: "lidar"
3
+ dim: 3
4
+ whiten: true
5
+ t_exclude: []
6
+ velocity_metric: "land"
7
+ gammas: [0.125]
8
+ rho: 0.001
9
+ branchsbm: true
10
+ seeds: [42]
11
+ patience_geopath: 50
12
+ time_geopath: true
13
+ branches: 2
14
+ metric_clusters: 3
configs/experiment/lidar_single.yaml ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ data_type: "lidar"
2
+ data_name: "lidarsingle"
3
+ dim: 3
4
+ whiten: true
5
+ t_exclude: []
6
+ velocity_metric: "land"
7
+ gammas: [0.125]
8
+ rho: 0.001
9
+ branchsbm: true
10
+ seeds: [42]
11
+ patience_geopath: 50
12
+ time_geopath: true
13
+ branches: 1
14
+ metric_clusters: 2
configs/experiment/mouse.yaml ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ data_type: "scrna"
2
+ data_name: "mouse"
3
+ hidden_dims_geopath: [1024, 1024, 1024]
4
+ hidden_dims_flow: [1024, 1024, 1024]
5
+ hidden_dims_growth: [1024, 1024, 1024]
6
+ dim: 2
7
+ whiten: false
8
+ t_exclude: []
9
+ velocity_metric: "land"
10
+ gammas: [0.125]
11
+ rho: 0.001
12
+ branchsbm: true
13
+ seeds: [42]
14
+ patience_geopath: 50
15
+ time_geopath: true
16
+ branches: 2
17
+ metric_clusters: 3
configs/experiment/trametinib.yaml ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ data_type: "tahoe"
2
+ data_name: "trametinib"
3
+ accelerator: "gpu"
4
+ hidden_dims_geopath: [1024, 1024, 1024]
5
+ hidden_dims_flow: [1024, 1024, 1024]
6
+ hidden_dims_growth: [1024, 1024, 1024]
7
+ dim: 50
8
+ t_exclude: []
9
+ time_geopath: true
10
+ whiten: false
11
+ velocity_metric: "rbf"
12
+ metric_patience: 25
13
+ patience: 25
14
+ n_centers: 150
15
+ kappa: 1.5
16
+ rho: -2.75
17
+ alpha_metric: 1
18
+ metric_epochs: 200
19
+ branchsbm: true
20
+ seeds: [42]
21
+ branches: 3
22
+ metric_clusters: 4
configs/experiment/trametinib_single.yaml ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ data_type: "tahoe"
2
+ data_name: "trametinibsingle"
3
+ accelerator: "gpu"
4
+ hidden_dims_geopath: [1024, 1024, 1024]
5
+ hidden_dims_flow: [1024, 1024, 1024]
6
+ hidden_dims_growth: [1024, 1024, 1024]
7
+ dim: 50
8
+ t_exclude: []
9
+ time_geopath: true
10
+ whiten: false
11
+ velocity_metric: "rbf"
12
+ metric_patience: 25
13
+ patience: 25
14
+ n_centers: 150
15
+ kappa: 1.5
16
+ rho: -2.75
17
+ alpha_metric: 1
18
+ metric_epochs: 200
19
+ branchsbm: true
20
+ seeds: [42]
21
+ branches: 1
22
+ metric_clusters: 2
dataloaders/.DS_Store ADDED
Binary file (6.15 kB). View file
 
dataloaders/clonidine_data.py ADDED
@@ -0,0 +1,269 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import sys
3
+ sys.argv = ['']
4
+ from sklearn.preprocessing import StandardScaler
5
+ import pytorch_lightning as pl
6
+ from torch.utils.data import DataLoader
7
+ from lightning.pytorch.utilities.combined_loader import CombinedLoader
8
+ import pandas as pd
9
+ import numpy as np
10
+ from functools import partial
11
+ from scipy.spatial import cKDTree
12
+ from sklearn.cluster import KMeans
13
+ from torch.utils.data import TensorDataset
14
+
15
+ #from train.parsers_tahoe import parse_args
16
+ #args = parse_args()
17
+
18
+ class DrugResponseDataModule(pl.LightningDataModule):
19
+ def __init__(self, args):
20
+ super().__init__()
21
+ self.save_hyperparameters()
22
+
23
+ self.batch_size = args.batch_size
24
+ self.max_dim = args.dim
25
+ self.whiten = args.whiten
26
+ self.split_ratios = args.split_ratios
27
+
28
+ # Path to your combined data
29
+ self.data_path = "/raid/st512/branchsbm/data/pca_and_leiden_labels.csv"
30
+ self.num_timesteps = 2
31
+ self.args = args
32
+ self._prepare_data()
33
+
34
+ def _prepare_data(self):
35
+ df = pd.read_csv(self.data_path, comment='#')
36
+ df = df.iloc[:, 1:]
37
+ df = df.replace('', np.nan)
38
+ pc_cols = df.columns[:50]
39
+ for col in pc_cols:
40
+ df[col] = pd.to_numeric(df[col], errors='coerce')
41
+ leiden_dmso_col = 'leiden_DMSO_TF_0.0uM'
42
+ leiden_clonidine_col = 'leiden_Clonidine (hydrochloride)_5.0uM'
43
+
44
+ dmso_mask = df[leiden_dmso_col].notna() # Has leiden value in DMSO column
45
+ clonidine_mask = df[leiden_clonidine_col].notna() # Has leiden value in Clonidine column
46
+
47
+ dmso_data = df[dmso_mask].copy()
48
+ clonidine_data = df[clonidine_mask].copy()
49
+
50
+ top_clonidine_clusters = ['0.0', '4.0']
51
+
52
+ x1_1_data = clonidine_data[clonidine_data[leiden_clonidine_col].astype(str) == top_clonidine_clusters[0]]
53
+ x1_2_data = clonidine_data[clonidine_data[leiden_clonidine_col].astype(str) == top_clonidine_clusters[1]]
54
+ x1_1_coords = x1_1_data[pc_cols].values
55
+ x1_2_coords = x1_2_data[pc_cols].values
56
+ x1_1_coords = x1_1_coords.astype(float)
57
+ x1_2_coords = x1_2_coords.astype(float)
58
+
59
+ target_size = min(len(x1_1_coords), len(x1_2_coords))
60
+
61
+ # Sample endpoint clusters to target size
62
+ np.random.seed(42)
63
+ if len(x1_1_coords) > target_size:
64
+ idx1 = np.random.choice(len(x1_1_coords), target_size, replace=False)
65
+ x1_1_coords = x1_1_coords[idx1]
66
+
67
+ if len(x1_2_coords) > target_size:
68
+ idx2 = np.random.choice(len(x1_2_coords), target_size, replace=False)
69
+ x1_2_coords = x1_2_coords[idx2]
70
+
71
+ dmso_cluster_counts = dmso_data[leiden_dmso_col].value_counts()
72
+
73
+ # DMSO
74
+ largest_dmso_cluster = dmso_cluster_counts.index[0]
75
+ dmso_cluster_data = dmso_data[dmso_data[leiden_dmso_col] == largest_dmso_cluster]
76
+
77
+ dmso_coords = dmso_cluster_data[pc_cols].values
78
+
79
+ # Random sampling from largest DMSO cluster to match target size
80
+ np.random.seed(42)
81
+ if len(dmso_coords) >= target_size:
82
+ idx0 = np.random.choice(len(dmso_coords), target_size, replace=False)
83
+ x0_coords = dmso_coords[idx0]
84
+ else:
85
+ # If largest cluster is smaller than target, use all of it and pad with other DMSO cells
86
+ remaining_needed = target_size - len(dmso_coords)
87
+ other_dmso_data = dmso_data[dmso_data[leiden_dmso_col] != largest_dmso_cluster]
88
+ other_dmso_coords = other_dmso_data[pc_cols].values
89
+
90
+ if len(other_dmso_coords) >= remaining_needed:
91
+ idx_other = np.random.choice(len(other_dmso_coords), remaining_needed, replace=False)
92
+ x0_coords = np.vstack([dmso_coords, other_dmso_coords[idx_other]])
93
+ else:
94
+ # Use all available DMSO cells and reduce target size
95
+ all_dmso_coords = dmso_data[pc_cols].values
96
+ target_size = min(target_size, len(all_dmso_coords))
97
+ idx0 = np.random.choice(len(all_dmso_coords), target_size, replace=False)
98
+ x0_coords = all_dmso_coords[idx0]
99
+
100
+ # Also resample endpoint clusters to match final target size
101
+ if len(x1_1_coords) > target_size:
102
+ idx1 = np.random.choice(len(x1_1_coords), target_size, replace=False)
103
+ x1_1_coords = x1_1_coords[idx1]
104
+
105
+ if len(x1_2_coords) > target_size:
106
+ idx2 = np.random.choice(len(x1_2_coords), target_size, replace=False)
107
+ x1_2_coords = x1_2_coords[idx2]
108
+
109
+ self.n_samples = target_size
110
+
111
+ x0 = torch.tensor(x0_coords, dtype=torch.float32)
112
+ x1_1 = torch.tensor(x1_1_coords, dtype=torch.float32)
113
+ x1_2 = torch.tensor(x1_2_coords, dtype=torch.float32)
114
+
115
+ self.coords_t0 = x0
116
+ self.coords_t1 = torch.cat([x1_1, x1_2], dim=0)
117
+ self.time_labels = np.concatenate([
118
+ np.zeros(len(self.coords_t0)), # t=0
119
+ np.ones(len(self.coords_t1)), # t=1
120
+ ])
121
+
122
+ split_index = int(target_size * self.split_ratios[0])
123
+
124
+ if target_size - split_index < self.batch_size:
125
+ split_index = target_size - self.batch_size
126
+
127
+ train_x0 = x0[:split_index]
128
+ val_x0 = x0[split_index:]
129
+ train_x1_1 = x1_1[:split_index]
130
+ val_x1_1 = x1_1[split_index:]
131
+ train_x1_2 = x1_2[:split_index]
132
+ val_x1_2 = x1_2[split_index:]
133
+
134
+ self.val_x0 = val_x0
135
+
136
+ train_x0_weights = torch.full((train_x0.shape[0], 1), fill_value=1.0)
137
+ train_x1_1_weights = torch.full((train_x1_1.shape[0], 1), fill_value=0.5)
138
+ train_x1_2_weights = torch.full((train_x1_2.shape[0], 1), fill_value=0.5)
139
+
140
+ val_x0_weights = torch.full((val_x0.shape[0], 1), fill_value=1.0)
141
+ val_x1_1_weights = torch.full((val_x1_1.shape[0], 1), fill_value=0.5)
142
+ val_x1_2_weights = torch.full((val_x1_2.shape[0], 1), fill_value=0.5)
143
+
144
+ self.train_dataloaders = {
145
+ "x0": DataLoader(TensorDataset(train_x0, train_x0_weights), batch_size=self.batch_size, shuffle=True, drop_last=True),
146
+ "x1_1": DataLoader(TensorDataset(train_x1_1, train_x1_1_weights), batch_size=self.batch_size, shuffle=True, drop_last=True),
147
+ "x1_2": DataLoader(TensorDataset(train_x1_2, train_x1_2_weights), batch_size=self.batch_size, shuffle=True, drop_last=True),
148
+ }
149
+
150
+ self.val_dataloaders = {
151
+ "x0": DataLoader(TensorDataset(val_x0, val_x0_weights), batch_size=self.batch_size, shuffle=False, drop_last=True),
152
+ "x1_1": DataLoader(TensorDataset(val_x1_1, val_x1_1_weights), batch_size=self.batch_size, shuffle=True, drop_last=True),
153
+ "x1_2": DataLoader(TensorDataset(val_x1_2, val_x1_2_weights), batch_size=self.batch_size, shuffle=True, drop_last=True),
154
+ }
155
+
156
+ all_coords = df[pc_cols].dropna().values.astype(float)
157
+ self.dataset = torch.tensor(all_coords, dtype=torch.float32)
158
+ self.tree = cKDTree(all_coords)
159
+
160
+ self.test_dataloaders = {
161
+ "x0": DataLoader(TensorDataset(val_x0, val_x0_weights), batch_size=self.val_x0.shape[0], shuffle=False, drop_last=False),
162
+ "dataset": DataLoader(TensorDataset(self.dataset), batch_size=self.dataset.shape[0], shuffle=False, drop_last=False),
163
+ }
164
+
165
+ # Metric samples
166
+ km_all = KMeans(n_clusters=3, random_state=0).fit(self.dataset.numpy())
167
+ cluster_labels = km_all.labels_
168
+
169
+ cluster_0_mask = cluster_labels == 0
170
+ cluster_1_mask = cluster_labels == 1
171
+ cluster_2_mask = cluster_labels == 2
172
+
173
+ samples = self.dataset.cpu().numpy()
174
+
175
+ cluster_0_data = samples[cluster_0_mask]
176
+ cluster_1_data = samples[cluster_1_mask]
177
+ cluster_2_data = samples[cluster_2_mask]
178
+
179
+ self.metric_samples_dataloaders = [
180
+ DataLoader(
181
+ torch.tensor(cluster_2_data, dtype=torch.float32),
182
+ batch_size=cluster_2_data.shape[0],
183
+ shuffle=False,
184
+ drop_last=False,
185
+ ),
186
+ DataLoader(
187
+ torch.tensor(cluster_0_data, dtype=torch.float32),
188
+ batch_size=cluster_0_data.shape[0],
189
+ shuffle=False,
190
+ drop_last=False,
191
+ ),
192
+
193
+ DataLoader(
194
+ torch.tensor(cluster_1_data, dtype=torch.float32),
195
+ batch_size=cluster_1_data.shape[0],
196
+ shuffle=False,
197
+ drop_last=False,
198
+ ),
199
+ ]
200
+
201
+ def train_dataloader(self):
202
+ combined_loaders = {
203
+ "train_samples": CombinedLoader(self.train_dataloaders, mode="min_size"),
204
+ "metric_samples": CombinedLoader(
205
+ self.metric_samples_dataloaders, mode="min_size"
206
+ ),
207
+ }
208
+ return CombinedLoader(combined_loaders, mode="max_size_cycle")
209
+
210
+ def val_dataloader(self):
211
+ combined_loaders = {
212
+ "val_samples": CombinedLoader(self.val_dataloaders, mode="min_size"),
213
+ "metric_samples": CombinedLoader(
214
+ self.metric_samples_dataloaders, mode="min_size"
215
+ ),
216
+ }
217
+
218
+ return CombinedLoader(combined_loaders, mode="max_size_cycle")
219
+
220
+
221
+
222
+ def test_dataloader(self):
223
+ combined_loaders = {
224
+ "test_samples": CombinedLoader(self.test_dataloaders, mode="min_size"),
225
+ "metric_samples": CombinedLoader(
226
+ self.metric_samples_dataloaders, mode="min_size"
227
+ ),
228
+ }
229
+
230
+ return CombinedLoader(combined_loaders, mode="max_size_cycle")
231
+
232
+ def get_manifold_proj(self, points):
233
+ """Adapted for 2D cell data - uses local neighborhood averaging instead of plane fitting"""
234
+ return partial(self.local_smoothing_op, tree=self.tree, dataset=self.dataset)
235
+
236
+ @staticmethod
237
+ def local_smoothing_op(x, tree, dataset, k=10, temp=1e-3):
238
+ """
239
+ Apply local smoothing based on k-nearest neighbors in the full dataset
240
+ This replaces the plane projection for 2D manifold regularization
241
+ """
242
+ points_np = x.detach().cpu().numpy()
243
+ _, idx = tree.query(points_np, k=k)
244
+ nearest_pts = dataset[idx] # Shape: (batch_size, k, 2)
245
+
246
+ # Compute weighted average of neighbors
247
+ dists = (x.unsqueeze(1) - nearest_pts).pow(2).sum(-1, keepdim=True)
248
+ weights = torch.exp(-dists / temp)
249
+ weights = weights / weights.sum(dim=1, keepdim=True)
250
+
251
+ # Weighted average of neighbors
252
+ smoothed = (weights * nearest_pts).sum(dim=1)
253
+
254
+ # Blend original point with smoothed version
255
+ alpha = 0.3 # How much smoothing to apply
256
+ return (1 - alpha) * x + alpha * smoothed
257
+
258
+ def get_timepoint_data(self):
259
+ """Return data organized by timepoints for visualization"""
260
+ return {
261
+ 't0': self.coords_t0,
262
+ 't1': self.coords_t1,
263
+ 'time_labels': self.time_labels
264
+ }
265
+
266
+ def get_datamodule():
267
+ datamodule = DrugResponseDataModule(args)
268
+ datamodule.setup(stage="fit")
269
+ return datamodule
dataloaders/clonidine_single_branch.py ADDED
@@ -0,0 +1,274 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import sys
3
+ sys.argv = ['']
4
+ from sklearn.preprocessing import StandardScaler
5
+ import pytorch_lightning as pl
6
+ from torch.utils.data import DataLoader
7
+ from lightning.pytorch.utilities.combined_loader import CombinedLoader
8
+ import pandas as pd
9
+ import numpy as np
10
+ from functools import partial
11
+ from scipy.spatial import cKDTree
12
+ from sklearn.cluster import KMeans
13
+ from torch.utils.data import TensorDataset
14
+
15
+ #uncomment for plotting
16
+ #from train.parsers_tahoe import parse_args
17
+ #args = parse_args()
18
+
19
+ class ClonidineSingleBranchDataModule(pl.LightningDataModule):
20
+ def __init__(self, args):
21
+ super().__init__()
22
+ self.save_hyperparameters()
23
+
24
+ self.batch_size = args.batch_size
25
+ self.max_dim = args.dim
26
+ self.whiten = args.whiten
27
+ self.split_ratios = args.split_ratios
28
+
29
+ self.dim = args.dim
30
+ print("dimension")
31
+ print(self.dim)
32
+ # Path to your combined data
33
+ self.data_path = "./data/pca_and_leiden_labels.csv"
34
+ self.num_timesteps = 2
35
+ self.args = args
36
+ self._prepare_data()
37
+
38
+ def _prepare_data(self):
39
+ df = pd.read_csv(self.data_path, comment='#')
40
+ df = df.iloc[:, 1:]
41
+ df = df.replace('', np.nan)
42
+ pc_cols = df.columns[:self.dim]
43
+ for col in pc_cols:
44
+ df[col] = pd.to_numeric(df[col], errors='coerce')
45
+ leiden_dmso_col = 'leiden_DMSO_TF_0.0uM'
46
+ leiden_clonidine_col = 'leiden_Clonidine (hydrochloride)_5.0uM'
47
+
48
+ dmso_mask = df[leiden_dmso_col].notna() # Has leiden value in DMSO column
49
+ clonidine_mask = df[leiden_clonidine_col].notna() # Has leiden value in Clonidine column
50
+
51
+ dmso_data = df[dmso_mask].copy()
52
+ clonidine_data = df[clonidine_mask].copy()
53
+
54
+ top_clonidine_clusters = ['0.0', '4.0']
55
+
56
+ x1_1_data = clonidine_data[clonidine_data[leiden_clonidine_col].astype(str) == top_clonidine_clusters[0]]
57
+ x1_2_data = clonidine_data[clonidine_data[leiden_clonidine_col].astype(str) == top_clonidine_clusters[1]]
58
+
59
+ x1_1_coords = x1_1_data[pc_cols].values
60
+ x1_2_coords = x1_2_data[pc_cols].values
61
+
62
+ x1_1_coords = x1_1_coords.astype(float)
63
+ x1_2_coords = x1_2_coords.astype(float)
64
+
65
+ # Target size is now the minimum across all three endpoint clusters
66
+ target_size = min(len(x1_1_coords), len(x1_2_coords),)
67
+
68
+ # Helper function to select points closest to centroid
69
+ def select_closest_to_centroid(coords, target_size):
70
+ if len(coords) <= target_size:
71
+ return coords
72
+
73
+ # Calculate centroid
74
+ centroid = np.mean(coords, axis=0)
75
+
76
+ # Calculate distances to centroid
77
+ distances = np.linalg.norm(coords - centroid, axis=1)
78
+
79
+ # Get indices of closest points
80
+ closest_indices = np.argsort(distances)[:target_size]
81
+
82
+ return coords[closest_indices]
83
+
84
+ # Sample all endpoint clusters to target size using centroid-based selection
85
+ x1_1_coords = select_closest_to_centroid(x1_1_coords, target_size)
86
+ x1_2_coords = select_closest_to_centroid(x1_2_coords, target_size)
87
+
88
+ dmso_cluster_counts = dmso_data[leiden_dmso_col].value_counts()
89
+
90
+ # DMSO (unchanged)
91
+ largest_dmso_cluster = dmso_cluster_counts.index[0]
92
+ dmso_cluster_data = dmso_data[dmso_data[leiden_dmso_col] == largest_dmso_cluster]
93
+
94
+ dmso_coords = dmso_cluster_data[pc_cols].values
95
+
96
+ # Random sampling from largest DMSO cluster to match target size
97
+ # For DMSO, we'll also use centroid-based selection for consistency
98
+ if len(dmso_coords) >= target_size:
99
+ x0_coords = select_closest_to_centroid(dmso_coords, target_size)
100
+ else:
101
+ # If largest cluster is smaller than target, use all of it and pad with other DMSO cells
102
+ remaining_needed = target_size - len(dmso_coords)
103
+ other_dmso_data = dmso_data[dmso_data[leiden_dmso_col] != largest_dmso_cluster]
104
+ other_dmso_coords = other_dmso_data[pc_cols].values
105
+
106
+ if len(other_dmso_coords) >= remaining_needed:
107
+ # Select closest to centroid from other DMSO cells
108
+ other_selected = select_closest_to_centroid(other_dmso_coords, remaining_needed)
109
+ x0_coords = np.vstack([dmso_coords, other_selected])
110
+ else:
111
+ # Use all available DMSO cells and reduce target size
112
+ all_dmso_coords = dmso_data[pc_cols].values
113
+ target_size = min(target_size, len(all_dmso_coords))
114
+ x0_coords = select_closest_to_centroid(all_dmso_coords, target_size)
115
+
116
+ # Re-select endpoint clusters with updated target size
117
+ x1_1_coords = select_closest_to_centroid(x1_1_data[pc_cols].values.astype(float), target_size)
118
+ x1_2_coords = select_closest_to_centroid(x1_2_data[pc_cols].values.astype(float), target_size)
119
+
120
+ # No need to resample since we already selected the right number
121
+ # The endpoint clusters are already at target_size from centroid-based selection
122
+
123
+ self.n_samples = target_size
124
+
125
+ x0 = torch.tensor(x0_coords, dtype=torch.float32)
126
+ x1_1 = torch.tensor(x1_1_coords, dtype=torch.float32)
127
+ x1_2 = torch.tensor(x1_2_coords, dtype=torch.float32)
128
+ x1 = torch.cat([x1_1, x1_2], dim=0)
129
+
130
+ self.coords_t0 = x0
131
+ self.coords_t1 = x1
132
+ self.time_labels = np.concatenate([
133
+ np.zeros(len(self.coords_t0)), # t=0
134
+ np.ones(len(self.coords_t1)), # t=1
135
+ ])
136
+
137
+ split_index = int(target_size * self.split_ratios[0])
138
+
139
+ if target_size - split_index < self.batch_size:
140
+ split_index = target_size - self.batch_size
141
+ print('total count is:', target_size)
142
+
143
+ train_x0 = x0[:split_index]
144
+ val_x0 = x0[split_index:]
145
+ train_x1 = x1[:split_index]
146
+ val_x1 = x1[split_index:]
147
+
148
+
149
+ self.val_x0 = val_x0
150
+
151
+ train_x0_weights = torch.full((train_x0.shape[0], 1), fill_value=1.0)
152
+ train_x1_weights = torch.full((train_x1.shape[0], 1), fill_value=1.0)
153
+
154
+ val_x0_weights = torch.full((val_x0.shape[0], 1), fill_value=1.0)
155
+ val_x1_weights = torch.full((val_x1.shape[0], 1), fill_value=1.0)
156
+
157
+ # Updated train dataloaders to include x1_3
158
+ self.train_dataloaders = {
159
+ "x0": DataLoader(TensorDataset(train_x0, train_x0_weights), batch_size=self.batch_size, shuffle=True, drop_last=True),
160
+ "x1": DataLoader(TensorDataset(train_x1, train_x1_weights), batch_size=self.batch_size, shuffle=True, drop_last=True),
161
+ }
162
+
163
+ self.val_dataloaders = {
164
+ "x0": DataLoader(TensorDataset(val_x0, val_x0_weights), batch_size=self.batch_size, shuffle=False, drop_last=True),
165
+ "x1": DataLoader(TensorDataset(val_x1, val_x1_weights), batch_size=self.batch_size, shuffle=True, drop_last=True),
166
+ }
167
+
168
+ all_coords = df[pc_cols].dropna().values.astype(float)
169
+ self.dataset = torch.tensor(all_coords, dtype=torch.float32)
170
+ self.tree = cKDTree(all_coords)
171
+
172
+ self.test_dataloaders = {
173
+ "x0": DataLoader(TensorDataset(val_x0, val_x0_weights), batch_size=self.val_x0.shape[0], shuffle=False, drop_last=False),
174
+ "dataset": DataLoader(TensorDataset(self.dataset), batch_size=self.dataset.shape[0], shuffle=False, drop_last=False),
175
+ }
176
+
177
+ # Updated metric samples - now using 4 clusters instead of 3
178
+ #km_all = KMeans(n_clusters=4, random_state=42).fit(self.dataset.numpy())
179
+ km_all = KMeans(n_clusters=2, random_state=0).fit(self.dataset.numpy())
180
+
181
+ cluster_labels = km_all.labels_
182
+
183
+ cluster_0_mask = cluster_labels == 0
184
+ cluster_1_mask = cluster_labels == 1
185
+
186
+ samples = self.dataset.cpu().numpy()
187
+
188
+ cluster_0_data = samples[cluster_0_mask]
189
+ cluster_1_data = samples[cluster_1_mask]
190
+
191
+ self.metric_samples_dataloaders = [
192
+ DataLoader(
193
+ torch.tensor(cluster_1_data, dtype=torch.float32),
194
+ batch_size=cluster_1_data.shape[0],
195
+ shuffle=False,
196
+ drop_last=False,
197
+ ),
198
+ DataLoader(
199
+ torch.tensor(cluster_0_data, dtype=torch.float32),
200
+ batch_size=cluster_0_data.shape[0],
201
+ shuffle=False,
202
+ drop_last=False,
203
+ ),
204
+ ]
205
+
206
+ def train_dataloader(self):
207
+ combined_loaders = {
208
+ "train_samples": CombinedLoader(self.train_dataloaders, mode="min_size"),
209
+ "metric_samples": CombinedLoader(
210
+ self.metric_samples_dataloaders, mode="min_size"
211
+ ),
212
+ }
213
+ return CombinedLoader(combined_loaders, mode="max_size_cycle")
214
+
215
+ def val_dataloader(self):
216
+ combined_loaders = {
217
+ "val_samples": CombinedLoader(self.val_dataloaders, mode="min_size"),
218
+ "metric_samples": CombinedLoader(
219
+ self.metric_samples_dataloaders, mode="min_size"
220
+ ),
221
+ }
222
+
223
+ return CombinedLoader(combined_loaders, mode="max_size_cycle")
224
+
225
+
226
+
227
+ def test_dataloader(self):
228
+ combined_loaders = {
229
+ "test_samples": CombinedLoader(self.test_dataloaders, mode="min_size"),
230
+ "metric_samples": CombinedLoader(
231
+ self.metric_samples_dataloaders, mode="min_size"
232
+ ),
233
+ }
234
+
235
+ return CombinedLoader(combined_loaders, mode="max_size_cycle")
236
+
237
+ def get_manifold_proj(self, points):
238
+ """Adapted for 2D cell data - uses local neighborhood averaging instead of plane fitting"""
239
+ return partial(self.local_smoothing_op, tree=self.tree, dataset=self.dataset)
240
+
241
+ @staticmethod
242
+ def local_smoothing_op(x, tree, dataset, k=10, temp=1e-3):
243
+ """
244
+ Apply local smoothing based on k-nearest neighbors in the full dataset
245
+ This replaces the plane projection for 2D manifold regularization
246
+ """
247
+ points_np = x.detach().cpu().numpy()
248
+ _, idx = tree.query(points_np, k=k)
249
+ nearest_pts = dataset[idx] # Shape: (batch_size, k, 2)
250
+
251
+ # Compute weighted average of neighbors
252
+ dists = (x.unsqueeze(1) - nearest_pts).pow(2).sum(-1, keepdim=True)
253
+ weights = torch.exp(-dists / temp)
254
+ weights = weights / weights.sum(dim=1, keepdim=True)
255
+
256
+ # Weighted average of neighbors
257
+ smoothed = (weights * nearest_pts).sum(dim=1)
258
+
259
+ # Blend original point with smoothed version
260
+ alpha = 0.3 # How much smoothing to apply
261
+ return (1 - alpha) * x + alpha * smoothed
262
+
263
+ def get_timepoint_data(self):
264
+ """Return data organized by timepoints for visualization"""
265
+ return {
266
+ 't0': self.coords_t0,
267
+ 't1': self.coords_t1,
268
+ 'time_labels': self.time_labels
269
+ }
270
+
271
+ def get_datamodule():
272
+ datamodule = ClonidineSingleBranchDataModule(args)
273
+ datamodule.setup(stage="fit")
274
+ return datamodule
dataloaders/clonidine_v2_data.py ADDED
@@ -0,0 +1,287 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import sys
3
+ sys.argv = ['']
4
+ from sklearn.preprocessing import StandardScaler
5
+ import pytorch_lightning as pl
6
+ from torch.utils.data import DataLoader
7
+ from lightning.pytorch.utilities.combined_loader import CombinedLoader
8
+ import pandas as pd
9
+ import numpy as np
10
+ from functools import partial
11
+ from scipy.spatial import cKDTree
12
+ from sklearn.cluster import KMeans
13
+ from torch.utils.data import TensorDataset
14
+
15
+ from train.parsers_tahoe import parse_args
16
+ args = parse_args()
17
+
18
+ class ClonidineV2DataModule(pl.LightningDataModule):
19
+ def __init__(self, args):
20
+ super().__init__()
21
+ self.save_hyperparameters()
22
+
23
+ self.batch_size = args.batch_size
24
+ self.max_dim = args.dim
25
+ self.whiten = args.whiten
26
+ self.split_ratios = args.split_ratios
27
+
28
+ self.dim = args.dim
29
+ print("dimension")
30
+ print(self.dim)
31
+ # Path to your combined data
32
+ self.data_path = "./data/pca_and_leiden_labels.csv"
33
+ self.num_timesteps = 2
34
+ self.args = args
35
+ self._prepare_data()
36
+
37
+ def _prepare_data(self):
38
+ df = pd.read_csv(self.data_path, comment='#')
39
+ df = df.iloc[:, 1:]
40
+ df = df.replace('', np.nan)
41
+ pc_cols = df.columns[:150]
42
+ for col in pc_cols:
43
+ df[col] = pd.to_numeric(df[col], errors='coerce')
44
+ leiden_dmso_col = 'leiden_DMSO_TF_0.0uM'
45
+ leiden_clonidine_col = 'leiden_Clonidine (hydrochloride)_5.0uM'
46
+
47
+ dmso_mask = df[leiden_dmso_col].notna() # Has leiden value in DMSO column
48
+ clonidine_mask = df[leiden_clonidine_col].notna() # Has leiden value in Clonidine column
49
+
50
+ dmso_data = df[dmso_mask].copy()
51
+ clonidine_data = df[clonidine_mask].copy()
52
+
53
+ top_clonidine_clusters = ['0.0', '4.0']
54
+
55
+ x1_1_data = clonidine_data[clonidine_data[leiden_clonidine_col].astype(str) == top_clonidine_clusters[0]]
56
+ x1_2_data = clonidine_data[clonidine_data[leiden_clonidine_col].astype(str) == top_clonidine_clusters[1]]
57
+
58
+ x1_1_coords = x1_1_data[pc_cols].values
59
+ x1_2_coords = x1_2_data[pc_cols].values
60
+
61
+ x1_1_coords = x1_1_coords.astype(float)
62
+ x1_2_coords = x1_2_coords.astype(float)
63
+
64
+ # Target size is now the minimum across all three endpoint clusters
65
+ target_size = min(len(x1_1_coords), len(x1_2_coords),)
66
+
67
+ # Helper function to select points closest to centroid
68
+ def select_closest_to_centroid(coords, target_size):
69
+ if len(coords) <= target_size:
70
+ return coords
71
+
72
+ # Calculate centroid
73
+ centroid = np.mean(coords, axis=0)
74
+
75
+ # Calculate distances to centroid
76
+ distances = np.linalg.norm(coords - centroid, axis=1)
77
+
78
+ # Get indices of closest points
79
+ closest_indices = np.argsort(distances)[:target_size]
80
+
81
+ return coords[closest_indices]
82
+
83
+ # Sample all endpoint clusters to target size using centroid-based selection
84
+ x1_1_coords = select_closest_to_centroid(x1_1_coords, target_size)
85
+ x1_2_coords = select_closest_to_centroid(x1_2_coords, target_size)
86
+
87
+ dmso_cluster_counts = dmso_data[leiden_dmso_col].value_counts()
88
+
89
+ # DMSO (unchanged)
90
+ largest_dmso_cluster = dmso_cluster_counts.index[0]
91
+ dmso_cluster_data = dmso_data[dmso_data[leiden_dmso_col] == largest_dmso_cluster]
92
+
93
+ dmso_coords = dmso_cluster_data[pc_cols].values
94
+
95
+ # Random sampling from largest DMSO cluster to match target size
96
+ # For DMSO, we'll also use centroid-based selection for consistency
97
+ if len(dmso_coords) >= target_size:
98
+ x0_coords = select_closest_to_centroid(dmso_coords, target_size)
99
+ else:
100
+ # If largest cluster is smaller than target, use all of it and pad with other DMSO cells
101
+ remaining_needed = target_size - len(dmso_coords)
102
+ other_dmso_data = dmso_data[dmso_data[leiden_dmso_col] != largest_dmso_cluster]
103
+ other_dmso_coords = other_dmso_data[pc_cols].values
104
+
105
+ if len(other_dmso_coords) >= remaining_needed:
106
+ # Select closest to centroid from other DMSO cells
107
+ other_selected = select_closest_to_centroid(other_dmso_coords, remaining_needed)
108
+ x0_coords = np.vstack([dmso_coords, other_selected])
109
+ else:
110
+ # Use all available DMSO cells and reduce target size
111
+ all_dmso_coords = dmso_data[pc_cols].values
112
+ target_size = min(target_size, len(all_dmso_coords))
113
+ x0_coords = select_closest_to_centroid(all_dmso_coords, target_size)
114
+
115
+ # Re-select endpoint clusters with updated target size
116
+ x1_1_coords = select_closest_to_centroid(x1_1_data[pc_cols].values.astype(float), target_size)
117
+ x1_2_coords = select_closest_to_centroid(x1_2_data[pc_cols].values.astype(float), target_size)
118
+
119
+ # No need to resample since we already selected the right number
120
+ # The endpoint clusters are already at target_size from centroid-based selection
121
+
122
+ self.n_samples = target_size
123
+
124
+ x0 = torch.tensor(x0_coords, dtype=torch.float32)
125
+ x1_1 = torch.tensor(x1_1_coords, dtype=torch.float32)
126
+ x1_2 = torch.tensor(x1_2_coords, dtype=torch.float32)
127
+
128
+ self.coords_t0 = x0
129
+ self.coords_t1_1 = x1_1
130
+ self.coords_t1_2 = x1_2
131
+ self.time_labels = np.concatenate([
132
+ np.zeros(len(self.coords_t0)), # t=0
133
+ np.ones(len(self.coords_t1_1)), # t=1
134
+ np.ones(len(self.coords_t1_2)),
135
+ ])
136
+
137
+ split_index = int(target_size * self.split_ratios[0])
138
+
139
+ if target_size - split_index < self.batch_size:
140
+ split_index = target_size - self.batch_size
141
+ print('total count is:', target_size)
142
+
143
+ train_x0 = x0[:split_index]
144
+ val_x0 = x0[split_index:]
145
+ train_x1_1 = x1_1[:split_index]
146
+ val_x1_1 = x1_1[split_index:]
147
+ train_x1_2 = x1_2[:split_index]
148
+ val_x1_2 = x1_2[split_index:]
149
+
150
+
151
+ self.val_x0 = val_x0
152
+
153
+ train_x0_weights = torch.full((train_x0.shape[0], 1), fill_value=1.0)
154
+ train_x1_1_weights = torch.full((train_x1_1.shape[0], 1), fill_value=0.5)
155
+ train_x1_2_weights = torch.full((train_x1_2.shape[0], 1), fill_value=0.5)
156
+
157
+ val_x0_weights = torch.full((val_x0.shape[0], 1), fill_value=1.0)
158
+ val_x1_1_weights = torch.full((val_x1_1.shape[0], 1), fill_value=0.5)
159
+ val_x1_2_weights = torch.full((val_x1_2.shape[0], 1), fill_value=0.5)
160
+
161
+ # Updated train dataloaders to include x1_3
162
+ self.train_dataloaders = {
163
+ "x0": DataLoader(TensorDataset(train_x0, train_x0_weights), batch_size=self.batch_size, shuffle=True, drop_last=True),
164
+ "x1_1": DataLoader(TensorDataset(train_x1_1, train_x1_1_weights), batch_size=self.batch_size, shuffle=True, drop_last=True),
165
+ "x1_2": DataLoader(TensorDataset(train_x1_2, train_x1_2_weights), batch_size=self.batch_size, shuffle=True, drop_last=True),
166
+ }
167
+
168
+ self.val_dataloaders = {
169
+ "x0": DataLoader(TensorDataset(val_x0, val_x0_weights), batch_size=self.batch_size, shuffle=False, drop_last=True),
170
+ "x1_1": DataLoader(TensorDataset(val_x1_1, val_x1_1_weights), batch_size=self.batch_size, shuffle=True, drop_last=True),
171
+ "x1_2": DataLoader(TensorDataset(val_x1_2, val_x1_2_weights), batch_size=self.batch_size, shuffle=True, drop_last=True),
172
+ }
173
+
174
+ all_coords = df[pc_cols].dropna().values.astype(float)
175
+ self.dataset = torch.tensor(all_coords, dtype=torch.float32)
176
+ self.tree = cKDTree(all_coords)
177
+
178
+ self.test_dataloaders = {
179
+ "x0": DataLoader(TensorDataset(val_x0, val_x0_weights), batch_size=self.val_x0.shape[0], shuffle=False, drop_last=False),
180
+ "dataset": DataLoader(TensorDataset(self.dataset), batch_size=self.dataset.shape[0], shuffle=False, drop_last=False),
181
+ }
182
+
183
+ km_all = KMeans(n_clusters=3, random_state=0).fit(self.dataset.numpy())
184
+
185
+ cluster_labels = km_all.labels_
186
+
187
+ cluster_0_mask = cluster_labels == 0
188
+ cluster_1_mask = cluster_labels == 1
189
+ cluster_2_mask = cluster_labels == 2
190
+
191
+ samples = self.dataset.cpu().numpy()
192
+
193
+ cluster_0_data = samples[cluster_0_mask]
194
+ cluster_1_data = samples[cluster_1_mask]
195
+ cluster_2_data = samples[cluster_2_mask]
196
+
197
+ self.metric_samples_dataloaders = [
198
+ DataLoader(
199
+ torch.tensor(cluster_2_data, dtype=torch.float32),
200
+ batch_size=cluster_2_data.shape[0],
201
+ shuffle=False,
202
+ drop_last=False,
203
+ ),
204
+ DataLoader(
205
+ torch.tensor(cluster_0_data, dtype=torch.float32),
206
+ batch_size=cluster_0_data.shape[0],
207
+ shuffle=False,
208
+ drop_last=False,
209
+ ),
210
+
211
+ DataLoader(
212
+ torch.tensor(cluster_1_data, dtype=torch.float32),
213
+ batch_size=cluster_1_data.shape[0],
214
+ shuffle=False,
215
+ drop_last=False,
216
+ ),
217
+ ]
218
+
219
+ def train_dataloader(self):
220
+ combined_loaders = {
221
+ "train_samples": CombinedLoader(self.train_dataloaders, mode="min_size"),
222
+ "metric_samples": CombinedLoader(
223
+ self.metric_samples_dataloaders, mode="min_size"
224
+ ),
225
+ }
226
+ return CombinedLoader(combined_loaders, mode="max_size_cycle")
227
+
228
+ def val_dataloader(self):
229
+ combined_loaders = {
230
+ "val_samples": CombinedLoader(self.val_dataloaders, mode="min_size"),
231
+ "metric_samples": CombinedLoader(
232
+ self.metric_samples_dataloaders, mode="min_size"
233
+ ),
234
+ }
235
+
236
+ return CombinedLoader(combined_loaders, mode="max_size_cycle")
237
+
238
+
239
+ def test_dataloader(self):
240
+ combined_loaders = {
241
+ "test_samples": CombinedLoader(self.test_dataloaders, mode="min_size"),
242
+ "metric_samples": CombinedLoader(
243
+ self.metric_samples_dataloaders, mode="min_size"
244
+ ),
245
+ }
246
+
247
+ return CombinedLoader(combined_loaders, mode="max_size_cycle")
248
+
249
+ def get_manifold_proj(self, points):
250
+ """Adapted for 2D cell data - uses local neighborhood averaging instead of plane fitting"""
251
+ return partial(self.local_smoothing_op, tree=self.tree, dataset=self.dataset)
252
+
253
+ @staticmethod
254
+ def local_smoothing_op(x, tree, dataset, k=10, temp=1e-3):
255
+ """
256
+ Apply local smoothing based on k-nearest neighbors in the full dataset
257
+ This replaces the plane projection for 2D manifold regularization
258
+ """
259
+ points_np = x.detach().cpu().numpy()
260
+ _, idx = tree.query(points_np, k=k)
261
+ nearest_pts = dataset[idx] # Shape: (batch_size, k, 2)
262
+
263
+ # Compute weighted average of neighbors
264
+ dists = (x.unsqueeze(1) - nearest_pts).pow(2).sum(-1, keepdim=True)
265
+ weights = torch.exp(-dists / temp)
266
+ weights = weights / weights.sum(dim=1, keepdim=True)
267
+
268
+ # Weighted average of neighbors
269
+ smoothed = (weights * nearest_pts).sum(dim=1)
270
+
271
+ # Blend original point with smoothed version
272
+ alpha = 0.3 # How much smoothing to apply
273
+ return (1 - alpha) * x + alpha * smoothed
274
+
275
+ def get_timepoint_data(self):
276
+ """Return data organized by timepoints for visualization"""
277
+ return {
278
+ 't0': self.coords_t0,
279
+ 't1_1': self.coords_t1_1,
280
+ 't1_2': self.coords_t1_2,
281
+ 'time_labels': self.time_labels
282
+ }
283
+
284
+ def get_datamodule():
285
+ datamodule = ClonidineV2DataModule(args)
286
+ datamodule.setup(stage="fit")
287
+ return datamodule
dataloaders/lidar_data.py ADDED
@@ -0,0 +1,532 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import sys
3
+ sys.argv = ['']
4
+ from sklearn.preprocessing import StandardScaler
5
+ import pytorch_lightning as pl
6
+ from torch.utils.data import DataLoader
7
+ from pytorch_lightning.utilities.combined_loader import CombinedLoader
8
+ import laspy
9
+ import numpy as np
10
+ from scipy.spatial import cKDTree
11
+ import math
12
+ from functools import partial
13
+ from torch.utils.data import TensorDataset
14
+
15
+ #from train.parsers import parse_args
16
+ #args = parse_args()
17
+
18
+ class GaussianMM:
19
+ def __init__(self, mu, var):
20
+ super().__init__()
21
+ self.centers = torch.tensor(mu)
22
+ self.logstd = torch.tensor(var).log() / 2.0
23
+ self.K = self.centers.shape[0]
24
+
25
+ def logprob(self, x):
26
+ logprobs = self.normal_logprob(
27
+ x.unsqueeze(1), self.centers.unsqueeze(0), self.logstd
28
+ )
29
+ logprobs = torch.sum(logprobs, dim=2)
30
+ return torch.logsumexp(logprobs, dim=1) - math.log(self.K)
31
+
32
+ def normal_logprob(self, z, mean, log_std):
33
+ mean = mean + torch.tensor(0.0)
34
+ log_std = log_std + torch.tensor(0.0)
35
+ c = torch.tensor([math.log(2 * math.pi)]).to(z)
36
+ inv_sigma = torch.exp(-log_std)
37
+ tmp = (z - mean) * inv_sigma
38
+ return -0.5 * (tmp * tmp + 2 * log_std + c)
39
+
40
+ def __call__(self, n_samples):
41
+ idx = torch.randint(self.K, (n_samples,)).to(self.centers.device)
42
+ mean = self.centers[idx]
43
+ return torch.randn(*mean.shape).to(mean) * torch.exp(self.logstd) + mean
44
+
45
+ class BranchedLidarDataModule(pl.LightningDataModule):
46
+ def __init__(self, args):
47
+ super().__init__()
48
+ self.save_hyperparameters()
49
+
50
+ self.data_path = args.data_path
51
+ self.batch_size = args.batch_size
52
+ self.max_dim = args.dim
53
+ self.whiten = args.whiten
54
+ self.p0_mu = [
55
+ [-4.5, -4.0, 0.5],
56
+ [-4.2, -3.5, 0.5],
57
+ [-4.0, -3.0, 0.5],
58
+ [-3.75, -2.5, 0.5],
59
+ ]
60
+ self.p0_var = 0.02
61
+
62
+ self.p1_1_mu = [
63
+ [-2.5, -0.25, 0.5],
64
+ [-2.25, 0.675, 0.5],
65
+ [-2, 1.5, 0.5],
66
+ ]
67
+ self.p1_2_mu = [
68
+ [2, -2, 0.5],
69
+ [2.6, -1.25, 0.5],
70
+ [3.2, -0.5, 0.5]
71
+ ]
72
+
73
+ self.p1_var = 0.03
74
+ self.k = 20
75
+ self.n_samples = 5000
76
+ self.num_timesteps = 2
77
+ self.split_ratios = args.split_ratios
78
+ self._prepare_data()
79
+
80
+ def assign_region(self):
81
+ all_centers = {
82
+ 0: torch.tensor(self.p0_mu), # Region 0: p0
83
+ 1: torch.tensor(self.p1_1_mu), # Region 1: p1_1
84
+ 2: torch.tensor(self.p1_2_mu), # Region 2: p1_2
85
+ }
86
+
87
+ dataset = self.dataset.to(torch.float32)
88
+ N = dataset.shape[0]
89
+ assignments = torch.zeros(N, dtype=torch.long)
90
+
91
+ # For each point, compute min distance to each region's centers
92
+ for i in range(N):
93
+ point = dataset[i]
94
+ min_dist = float("inf")
95
+ best_region = 0
96
+ for region, centers in all_centers.items():
97
+ dists = ((centers - point)**2).sum(dim=1)
98
+ region_min = dists.min()
99
+ if region_min < min_dist:
100
+ min_dist = region_min
101
+ best_region = region
102
+ assignments[i] = best_region
103
+ return assignments
104
+
105
+ def _prepare_data(self):
106
+ las = laspy.read(self.data_path)
107
+ # Extract only "ground" points.
108
+ self.mask = las.classification == 2
109
+ # Original Preprocessing
110
+ x_offset, x_scale = las.header.offsets[0], las.header.scales[0]
111
+ y_offset, y_scale = las.header.offsets[1], las.header.scales[1]
112
+ z_offset, z_scale = las.header.offsets[2], las.header.scales[2]
113
+ dataset = np.vstack(
114
+ (
115
+ las.X[self.mask] * x_scale + x_offset,
116
+ las.Y[self.mask] * y_scale + y_offset,
117
+ las.Z[self.mask] * z_scale + z_offset,
118
+ )
119
+ ).transpose()
120
+ mi = dataset.min(axis=0, keepdims=True)
121
+ ma = dataset.max(axis=0, keepdims=True)
122
+ dataset = (dataset - mi) / (ma - mi) * [10.0, 10.0, 2.0] + [-5.0, -5.0, 0.0]
123
+
124
+ self.dataset = torch.tensor(dataset, dtype=torch.float32)
125
+ self.tree = cKDTree(dataset)
126
+
127
+ x0_gaussian = GaussianMM(self.p0_mu, self.p0_var)(self.n_samples)
128
+ x1_1_gaussian = GaussianMM(self.p1_1_mu, self.p1_var)(self.n_samples)
129
+ x1_2_gaussian = GaussianMM(self.p1_2_mu, self.p1_var)(self.n_samples)
130
+
131
+ x0 = self.get_tangent_proj(x0_gaussian)(x0_gaussian)
132
+ x1_1 = self.get_tangent_proj(x1_1_gaussian)(x1_1_gaussian)
133
+ x1_2 = self.get_tangent_proj(x1_2_gaussian)(x1_2_gaussian)
134
+
135
+ split_index = int(self.n_samples * self.split_ratios[0])
136
+
137
+ self.scaler = StandardScaler()
138
+ if self.whiten:
139
+ self.dataset = torch.tensor(
140
+ self.scaler.fit_transform(dataset), dtype=torch.float32
141
+ )
142
+ x0 = torch.tensor(self.scaler.transform(x0), dtype=torch.float32)
143
+ x1_1 = torch.tensor(self.scaler.transform(x1_1), dtype=torch.float32)
144
+ x1_2 = torch.tensor(self.scaler.transform(x1_2), dtype=torch.float32)
145
+
146
+ train_x0 = x0[:split_index]
147
+ val_x0 = x0[split_index:]
148
+
149
+ # branches
150
+ train_x1_1 = x1_1[:split_index]
151
+ print("train_x1_1")
152
+ print(train_x1_1.shape)
153
+ val_x1_1 = x1_1[split_index:]
154
+ train_x1_2 = x1_2[:split_index]
155
+ val_x1_2 = x1_2[split_index:]
156
+
157
+ self.val_x0 = val_x0
158
+
159
+ # Adjust split_index to ensure minimum validation samples
160
+ if self.n_samples - split_index < self.batch_size:
161
+ split_index = self.n_samples - self.batch_size
162
+
163
+ self.train_dataloaders = {
164
+ "x0": DataLoader(train_x0, batch_size=self.batch_size, shuffle=True, drop_last=True),
165
+ "x1_1": DataLoader(train_x1_1, batch_size=self.batch_size, shuffle=True, drop_last=True),
166
+ "x1_2": DataLoader(train_x1_2, batch_size=self.batch_size, shuffle=True, drop_last=True),
167
+ }
168
+ self.val_dataloaders = {
169
+ "x0": DataLoader(val_x0, batch_size=self.batch_size, shuffle=False, drop_last=True),
170
+ "x1_1": DataLoader(val_x1_1, batch_size=self.batch_size, shuffle=True, drop_last=True),
171
+ "x1_2": DataLoader(val_x1_2, batch_size=self.batch_size, shuffle=True, drop_last=True),
172
+ }
173
+ # to edit?
174
+ self.test_dataloaders = [
175
+ DataLoader(
176
+ self.val_x0,
177
+ batch_size=self.val_x0.shape[0],
178
+ shuffle=False,
179
+ drop_last=False,
180
+ ),
181
+ DataLoader(
182
+ self.dataset,
183
+ batch_size=self.dataset.shape[0],
184
+ shuffle=False,
185
+ drop_last=False,
186
+ ),
187
+ ]
188
+
189
+ points = self.dataset.cpu().numpy()
190
+ x, y = points[:, 0], points[:, 1]
191
+ # Diagonal-based coordinates (rotated 45°)
192
+ u = (x + y) / np.sqrt(2) # along x=y
193
+ # start region (A) using u
194
+ u_thresh = np.percentile(u, 30) # tweak this threshold to control size
195
+ mask_A = u <= u_thresh
196
+
197
+ # among the rest, split by x=y diagonal
198
+ remaining = ~mask_A
199
+ mask_B = remaining & (x < y) # left of diagonal
200
+ mask_C = remaining & (x >= y) # right of diagonal
201
+
202
+ # Assign dataloaders
203
+ self.metric_samples_dataloaders = [
204
+ DataLoader(torch.tensor(points[mask_A], dtype=torch.float32), batch_size=points[mask_A].shape[0], shuffle=False),
205
+ DataLoader(torch.tensor(points[mask_B], dtype=torch.float32), batch_size=points[mask_B].shape[0], shuffle=False),
206
+ DataLoader(torch.tensor(points[mask_C], dtype=torch.float32), batch_size=points[mask_C].shape[0], shuffle=False),
207
+ ]
208
+
209
+ def train_dataloader(self):
210
+ combined_loaders = {
211
+ "train_samples": CombinedLoader(self.train_dataloaders, mode="min_size"),
212
+ "metric_samples": CombinedLoader(
213
+ self.metric_samples_dataloaders, mode="min_size"
214
+ ),
215
+ }
216
+ return CombinedLoader(combined_loaders, mode="max_size_cycle")
217
+
218
+ def val_dataloader(self):
219
+ combined_loaders = {
220
+ "val_samples": CombinedLoader(self.val_dataloaders, mode="min_size"),
221
+ "metric_samples": CombinedLoader(
222
+ self.metric_samples_dataloaders, mode="min_size"
223
+ ),
224
+ }
225
+
226
+ return CombinedLoader(combined_loaders, mode="max_size_cycle")
227
+
228
+ def test_dataloader(self):
229
+ return CombinedLoader(self.test_dataloaders)
230
+
231
+ def get_tangent_proj(self, points):
232
+ w = self.get_tangent_plane(points)
233
+ return partial(BranchedLidarDataModule.projection_op, w=w)
234
+
235
+ def get_tangent_plane(self, points, temp=1e-3):
236
+ points_np = points.detach().cpu().numpy()
237
+ _, idx = self.tree.query(points_np, k=self.k)
238
+ nearest_pts = self.dataset[idx]
239
+ nearest_pts = torch.tensor(nearest_pts).to(points)
240
+
241
+ dists = (points.unsqueeze(1) - nearest_pts).pow(2).sum(-1, keepdim=True)
242
+ weights = torch.exp(-dists / temp)
243
+
244
+ # Fits plane with least vertical distance.
245
+ w = BranchedLidarDataModule.fit_plane(nearest_pts, weights)
246
+ return w
247
+
248
+ @staticmethod
249
+ def fit_plane(points, weights=None):
250
+ """Expects points to be of shape (..., 3).
251
+ Returns [a, b, c] such that the plane is defined as
252
+ ax + by + c = z
253
+ """
254
+ D = torch.cat([points[..., :2], torch.ones_like(points[..., 2:3])], dim=-1)
255
+ z = points[..., 2]
256
+ if weights is not None:
257
+ Dtrans = D.transpose(-1, -2)
258
+ else:
259
+ DW = D * weights
260
+ Dtrans = DW.transpose(-1, -2)
261
+ w = torch.linalg.solve(
262
+ torch.matmul(Dtrans, D), torch.matmul(Dtrans, z.unsqueeze(-1))
263
+ ).squeeze(-1)
264
+ return w
265
+
266
+ @staticmethod
267
+ def projection_op(x, w):
268
+ """Projects points to a plane defined by w."""
269
+ # Normal vector to the tangent plane.
270
+ n = torch.cat([w[..., :2], -torch.ones_like(w[..., 2:3])], dim=1)
271
+
272
+ pn = torch.sum(x * n, dim=-1, keepdim=True)
273
+ nn = torch.sum(n * n, dim=-1, keepdim=True)
274
+
275
+ # Offset.
276
+ d = w[..., 2:3]
277
+
278
+ # Projection of x onto n.
279
+ projn_x = ((pn + d) / nn) * n
280
+
281
+ # Remove component in the normal direction.
282
+ return x - projn_x
283
+
284
+ class WeightedBranchedLidarDataModule(pl.LightningDataModule):
285
+ def __init__(self, args):
286
+ super().__init__()
287
+ self.save_hyperparameters()
288
+
289
+ self.data_path = args.data_path
290
+ self.batch_size = args.batch_size
291
+ self.max_dim = args.dim
292
+ self.whiten = args.whiten
293
+ self.p0_mu = [
294
+ [-4.5, -4.0, 0.5],
295
+ [-4.2, -3.5, 0.5],
296
+ [-4.0, -3.0, 0.5],
297
+ [-3.75, -2.5, 0.5],
298
+ ]
299
+ self.p0_var = 0.02
300
+ # multiple p1 for each branch
301
+ #changed
302
+ self.p1_1_mu = [
303
+ [-2.5, -0.25, 0.5],
304
+ [-2.25, 0.675, 0.5],
305
+ [-2, 1.5, 0.5],
306
+ ]
307
+ self.p1_2_mu = [
308
+ [2, -2, 0.5],
309
+ [2.6, -1.25, 0.5],
310
+ [3.2, -0.5, 0.5]
311
+ ]
312
+
313
+ self.p1_var = 0.03
314
+ self.k = 20
315
+ self.n_samples = 5000
316
+ self.num_timesteps = 2
317
+ self.split_ratios = args.split_ratios
318
+
319
+ self.num_timesteps = 2
320
+ self.metric_clusters = 3
321
+ self.args = args
322
+ self._prepare_data()
323
+
324
+ def _prepare_data(self):
325
+ las = laspy.read(self.data_path)
326
+ # Extract only "ground" points.
327
+ self.mask = las.classification == 2
328
+ # Original Preprocessing
329
+ x_offset, x_scale = las.header.offsets[0], las.header.scales[0]
330
+ y_offset, y_scale = las.header.offsets[1], las.header.scales[1]
331
+ z_offset, z_scale = las.header.offsets[2], las.header.scales[2]
332
+ dataset = np.vstack(
333
+ (
334
+ las.X[self.mask] * x_scale + x_offset,
335
+ las.Y[self.mask] * y_scale + y_offset,
336
+ las.Z[self.mask] * z_scale + z_offset,
337
+ )
338
+ ).transpose()
339
+ mi = dataset.min(axis=0, keepdims=True)
340
+ ma = dataset.max(axis=0, keepdims=True)
341
+ dataset = (dataset - mi) / (ma - mi) * [10.0, 10.0, 2.0] + [-5.0, -5.0, 0.0]
342
+
343
+ self.dataset = torch.tensor(dataset, dtype=torch.float32)
344
+ self.tree = cKDTree(dataset)
345
+
346
+ x0_gaussian = GaussianMM(self.p0_mu, self.p0_var)(self.n_samples)
347
+ x1_1_gaussian = GaussianMM(self.p1_1_mu, self.p1_var)(self.n_samples)
348
+ x1_2_gaussian = GaussianMM(self.p1_2_mu, self.p1_var)(self.n_samples)
349
+
350
+ x0 = self.get_tangent_proj(x0_gaussian)(x0_gaussian)
351
+ x1_1 = self.get_tangent_proj(x1_1_gaussian)(x1_1_gaussian)
352
+ x1_2 = self.get_tangent_proj(x1_2_gaussian)(x1_2_gaussian)
353
+
354
+ split_index = int(self.n_samples * self.split_ratios[0])
355
+
356
+ self.scaler = StandardScaler()
357
+ if self.whiten:
358
+ self.dataset = torch.tensor(
359
+ self.scaler.fit_transform(dataset), dtype=torch.float32
360
+ )
361
+ x0 = torch.tensor(self.scaler.transform(x0), dtype=torch.float32)
362
+ x1_1 = torch.tensor(self.scaler.transform(x1_1), dtype=torch.float32)
363
+ x1_2 = torch.tensor(self.scaler.transform(x1_2), dtype=torch.float32)
364
+
365
+ self.coords_t0 = x0
366
+ self.coords_t1_1 = x1_1
367
+ self.coords_t1_2 = x1_2
368
+ self.time_labels = np.concatenate([
369
+ np.zeros(len(self.coords_t0)), # t=0
370
+ np.ones(len(self.coords_t1_1)), # t=1
371
+ np.ones(len(self.coords_t1_2)), # t=1
372
+ ])
373
+
374
+ train_x0 = x0[:split_index]
375
+ val_x0 = x0[split_index:]
376
+
377
+ # branches
378
+ train_x1_1 = x1_1[:split_index]
379
+
380
+ val_x1_1 = x1_1[split_index:]
381
+ train_x1_2 = x1_2[:split_index]
382
+ val_x1_2 = x1_2[split_index:]
383
+
384
+ self.val_x0 = val_x0
385
+
386
+ train_x0_weights = torch.full((train_x0.shape[0], 1), fill_value=1.0)
387
+ train_x1_1_weights = torch.full((train_x1_1.shape[0], 1), fill_value=0.5)
388
+ train_x1_2_weights = torch.full((train_x1_2.shape[0], 1), fill_value=0.5)
389
+
390
+ val_x0_weights = torch.full((val_x0.shape[0], 1), fill_value=1.0)
391
+ val_x1_1_weights = torch.full((val_x1_1.shape[0], 1), fill_value=0.5)
392
+ val_x1_2_weights = torch.full((val_x1_2.shape[0], 1), fill_value=0.5)
393
+
394
+ # Adjust split_index to ensure minimum validation samples
395
+ if self.n_samples - split_index < self.batch_size:
396
+ split_index = self.n_samples - self.batch_size
397
+
398
+ self.train_dataloaders = {
399
+ "x0": DataLoader(TensorDataset(train_x0, train_x0_weights), batch_size=self.batch_size, shuffle=True, drop_last=True),
400
+ "x1_1": DataLoader(TensorDataset(train_x1_1, train_x1_1_weights), batch_size=self.batch_size, shuffle=True, drop_last=True),
401
+ "x1_2": DataLoader(TensorDataset(train_x1_2, train_x1_2_weights), batch_size=self.batch_size, shuffle=True, drop_last=True),
402
+ }
403
+
404
+ self.val_dataloaders = {
405
+ "x0": DataLoader(TensorDataset(val_x0, val_x0_weights), batch_size=self.batch_size, shuffle=False, drop_last=True),
406
+ "x1_1": DataLoader(TensorDataset(val_x1_1, val_x1_1_weights), batch_size=self.batch_size, shuffle=True, drop_last=True),
407
+ "x1_2": DataLoader(TensorDataset(val_x1_2, val_x1_2_weights), batch_size=self.batch_size, shuffle=True, drop_last=True),
408
+ }
409
+
410
+ # to edit?
411
+ self.test_dataloaders = {
412
+ "x0": DataLoader(TensorDataset(val_x0, val_x0_weights), batch_size=self.val_x0.shape[0], shuffle=False, drop_last=False),
413
+ "x1_1": DataLoader(TensorDataset(val_x1_1, val_x1_1_weights), batch_size=self.val_x0.shape[0], shuffle=True, drop_last=True),
414
+ "x1_2": DataLoader(TensorDataset(val_x1_2, val_x1_2_weights), batch_size=self.val_x0.shape[0], shuffle=True, drop_last=True),
415
+ "dataset": DataLoader(TensorDataset(self.dataset), batch_size=self.dataset.shape[0], shuffle=False, drop_last=False),
416
+ }
417
+
418
+ points = self.dataset.cpu().numpy()
419
+ x, y = points[:, 0], points[:, 1]
420
+ # Diagonal-based coordinates (rotated 45°)
421
+ u = (x + y) / np.sqrt(2) # along x=y
422
+ # start region (A) using u
423
+ u_thresh = np.percentile(u, 30) # tweak this threshold to control size
424
+ mask_A = u <= u_thresh
425
+
426
+ # among the rest, split by x=y diagonal
427
+ remaining = ~mask_A
428
+ mask_B = remaining & (x < y) # left of diagonal
429
+ mask_C = remaining & (x >= y) # right of diagonal
430
+
431
+ # Assign dataloaders
432
+ self.metric_samples_dataloaders = [
433
+ DataLoader(torch.tensor(points[mask_A], dtype=torch.float32), batch_size=points[mask_A].shape[0], shuffle=False),
434
+ DataLoader(torch.tensor(points[mask_B], dtype=torch.float32), batch_size=points[mask_B].shape[0], shuffle=False),
435
+ DataLoader(torch.tensor(points[mask_C], dtype=torch.float32), batch_size=points[mask_C].shape[0], shuffle=False),
436
+ ]
437
+
438
+ def train_dataloader(self):
439
+ combined_loaders = {
440
+ "train_samples": CombinedLoader(self.train_dataloaders, mode="min_size"),
441
+ "metric_samples": CombinedLoader(
442
+ self.metric_samples_dataloaders, mode="min_size"
443
+ ),
444
+ }
445
+ return CombinedLoader(combined_loaders, mode="max_size_cycle")
446
+
447
+ def val_dataloader(self):
448
+ combined_loaders = {
449
+ "val_samples": CombinedLoader(self.val_dataloaders, mode="min_size"),
450
+ "metric_samples": CombinedLoader(
451
+ self.metric_samples_dataloaders, mode="min_size"
452
+ ),
453
+ }
454
+
455
+ return CombinedLoader(combined_loaders, mode="max_size_cycle")
456
+
457
+ def test_dataloader(self):
458
+ combined_loaders = {
459
+ "test_samples": CombinedLoader(self.test_dataloaders, mode="min_size"),
460
+ "metric_samples": CombinedLoader(
461
+ self.metric_samples_dataloaders, mode="min_size"
462
+ ),
463
+ }
464
+
465
+ return CombinedLoader(combined_loaders, mode="max_size_cycle")
466
+
467
+ def get_tangent_proj(self, points):
468
+ w = self.get_tangent_plane(points)
469
+ return partial(BranchedLidarDataModule.projection_op, w=w)
470
+
471
+ def get_tangent_plane(self, points, temp=1e-3):
472
+ points_np = points.detach().cpu().numpy()
473
+ _, idx = self.tree.query(points_np, k=self.k)
474
+ nearest_pts = self.dataset[idx]
475
+ nearest_pts = torch.tensor(nearest_pts).to(points)
476
+
477
+ dists = (points.unsqueeze(1) - nearest_pts).pow(2).sum(-1, keepdim=True)
478
+ weights = torch.exp(-dists / temp)
479
+
480
+ # Fits plane with least vertical distance.
481
+ w = BranchedLidarDataModule.fit_plane(nearest_pts, weights)
482
+ return w
483
+
484
+ @staticmethod
485
+ def fit_plane(points, weights=None):
486
+ """Expects points to be of shape (..., 3).
487
+ Returns [a, b, c] such that the plane is defined as
488
+ ax + by + c = z
489
+ """
490
+ D = torch.cat([points[..., :2], torch.ones_like(points[..., 2:3])], dim=-1)
491
+ z = points[..., 2]
492
+ if weights is not None:
493
+ Dtrans = D.transpose(-1, -2)
494
+ else:
495
+ DW = D * weights
496
+ Dtrans = DW.transpose(-1, -2)
497
+ w = torch.linalg.solve(
498
+ torch.matmul(Dtrans, D), torch.matmul(Dtrans, z.unsqueeze(-1))
499
+ ).squeeze(-1)
500
+ return w
501
+
502
+ @staticmethod
503
+ def projection_op(x, w):
504
+ """Projects points to a plane defined by w."""
505
+ # Normal vector to the tangent plane.
506
+ n = torch.cat([w[..., :2], -torch.ones_like(w[..., 2:3])], dim=1)
507
+
508
+ pn = torch.sum(x * n, dim=-1, keepdim=True)
509
+ nn = torch.sum(n * n, dim=-1, keepdim=True)
510
+
511
+ # Offset.
512
+ d = w[..., 2:3]
513
+
514
+ # Projection of x onto n.
515
+ projn_x = ((pn + d) / nn) * n
516
+
517
+ # Remove component in the normal direction.
518
+ return x - projn_x
519
+
520
+ def get_timepoint_data(self):
521
+ """Return data organized by timepoints for visualization"""
522
+ return {
523
+ 't0': self.coords_t0,
524
+ 't1_1': self.coords_t1_1,
525
+ 't1_2': self.coords_t1_2,
526
+ 'time_labels': self.time_labels
527
+ }
528
+
529
+ def get_datamodule():
530
+ datamodule = WeightedBranchedLidarDataModule(args)
531
+ datamodule.setup(stage="fit")
532
+ return datamodule
dataloaders/lidar_data_single.py ADDED
@@ -0,0 +1,282 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import sys
3
+ sys.argv = ['']
4
+ from sklearn.preprocessing import StandardScaler
5
+ import pytorch_lightning as pl
6
+ from torch.utils.data import DataLoader
7
+ from pytorch_lightning.utilities.combined_loader import CombinedLoader
8
+ import laspy
9
+ import numpy as np
10
+ from scipy.spatial import cKDTree
11
+ import math
12
+ from functools import partial
13
+ from torch.utils.data import TensorDataset
14
+
15
+ from train.parsers import parse_args
16
+ args = parse_args()
17
+
18
+ class GaussianMM:
19
+ def __init__(self, mu, var):
20
+ super().__init__()
21
+ self.centers = torch.tensor(mu)
22
+ self.logstd = torch.tensor(var).log() / 2.0
23
+ self.K = self.centers.shape[0]
24
+
25
+ def logprob(self, x):
26
+ logprobs = self.normal_logprob(
27
+ x.unsqueeze(1), self.centers.unsqueeze(0), self.logstd
28
+ )
29
+ logprobs = torch.sum(logprobs, dim=2)
30
+ return torch.logsumexp(logprobs, dim=1) - math.log(self.K)
31
+
32
+ def normal_logprob(self, z, mean, log_std):
33
+ mean = mean + torch.tensor(0.0)
34
+ log_std = log_std + torch.tensor(0.0)
35
+ c = torch.tensor([math.log(2 * math.pi)]).to(z)
36
+ inv_sigma = torch.exp(-log_std)
37
+ tmp = (z - mean) * inv_sigma
38
+ return -0.5 * (tmp * tmp + 2 * log_std + c)
39
+
40
+ def __call__(self, n_samples):
41
+ idx = torch.randint(self.K, (n_samples,)).to(self.centers.device)
42
+ mean = self.centers[idx]
43
+ return torch.randn(*mean.shape).to(mean) * torch.exp(self.logstd) + mean
44
+
45
+ class LidarSingleDataModule(pl.LightningDataModule):
46
+ def __init__(self, args):
47
+ super().__init__()
48
+ self.save_hyperparameters()
49
+
50
+ self.data_path = args.data_path
51
+ self.batch_size = args.batch_size
52
+ self.max_dim = args.dim
53
+ self.whiten = args.whiten
54
+ self.p0_mu = [
55
+ [-4.5, -4.0, 0.5],
56
+ [-4.2, -3.5, 0.5],
57
+ [-4.0, -3.0, 0.5],
58
+ [-3.75, -2.5, 0.5],
59
+ ]
60
+ self.p0_var = 0.02
61
+ # multiple p1 for each branch
62
+ #changed
63
+ self.p1_1_mu = [
64
+ [-2.5, -0.25, 0.5],
65
+ [-2.25, 0.675, 0.5],
66
+ [-2, 1.5, 0.5],
67
+ ]
68
+ self.p1_2_mu = [
69
+ [2, -2, 0.5],
70
+ [2.6, -1.25, 0.5],
71
+ [3.2, -0.5, 0.5]
72
+ ]
73
+
74
+ self.p1_var = 0.03
75
+ self.k = 20
76
+ self.n_samples = 5000
77
+ self.num_timesteps = 2
78
+ self.split_ratios = args.split_ratios
79
+
80
+ self.num_timesteps = 2
81
+ self.metric_clusters = 3
82
+ self.args = args
83
+ self._prepare_data()
84
+
85
+ def _prepare_data(self):
86
+ las = laspy.read(self.data_path)
87
+ # Extract only "ground" points.
88
+ self.mask = las.classification == 2
89
+ # Original Preprocessing
90
+ x_offset, x_scale = las.header.offsets[0], las.header.scales[0]
91
+ y_offset, y_scale = las.header.offsets[1], las.header.scales[1]
92
+ z_offset, z_scale = las.header.offsets[2], las.header.scales[2]
93
+ dataset = np.vstack(
94
+ (
95
+ las.X[self.mask] * x_scale + x_offset,
96
+ las.Y[self.mask] * y_scale + y_offset,
97
+ las.Z[self.mask] * z_scale + z_offset,
98
+ )
99
+ ).transpose()
100
+ mi = dataset.min(axis=0, keepdims=True)
101
+ ma = dataset.max(axis=0, keepdims=True)
102
+ dataset = (dataset - mi) / (ma - mi) * [10.0, 10.0, 2.0] + [-5.0, -5.0, 0.0]
103
+
104
+ self.dataset = torch.tensor(dataset, dtype=torch.float32)
105
+ self.tree = cKDTree(dataset)
106
+
107
+ x0_gaussian = GaussianMM(self.p0_mu, self.p0_var)(self.n_samples)
108
+ x1_1_gaussian = GaussianMM(self.p1_1_mu, self.p1_var)(self.n_samples)
109
+ x1_2_gaussian = GaussianMM(self.p1_2_mu, self.p1_var)(self.n_samples)
110
+
111
+ x0 = self.get_tangent_proj(x0_gaussian)(x0_gaussian)
112
+ x1_1 = self.get_tangent_proj(x1_1_gaussian)(x1_1_gaussian)
113
+ x1_2 = self.get_tangent_proj(x1_2_gaussian)(x1_2_gaussian)
114
+
115
+ split_index = int(self.n_samples * self.split_ratios[0])
116
+
117
+ self.scaler = StandardScaler()
118
+ if self.whiten:
119
+ self.dataset = torch.tensor(
120
+ self.scaler.fit_transform(dataset), dtype=torch.float32
121
+ )
122
+ x0 = torch.tensor(self.scaler.transform(x0), dtype=torch.float32)
123
+ x1_1 = torch.tensor(self.scaler.transform(x1_1), dtype=torch.float32)
124
+ x1_2 = torch.tensor(self.scaler.transform(x1_2), dtype=torch.float32)
125
+ x1 = torch.cat([x1_1, x1_2], dim=0)
126
+
127
+ self.coords_t0 = x0
128
+ self.coords_t1 = x1
129
+ self.time_labels = np.concatenate([
130
+ np.zeros(len(self.coords_t0)), # t=0
131
+ np.ones(len(self.coords_t1)), # t=1
132
+ ])
133
+
134
+ train_x0 = x0[:split_index]
135
+ val_x0 = x0[split_index:]
136
+
137
+ # branches
138
+ train_x1 = x1[:split_index]
139
+ val_x1 = x1[split_index:]
140
+
141
+ self.val_x0 = val_x0
142
+
143
+ train_x0_weights = torch.full((train_x0.shape[0], 1), fill_value=1.0)
144
+ train_x1_weights = torch.full((train_x1.shape[0], 1), fill_value=1.0)
145
+
146
+ val_x0_weights = torch.full((val_x0.shape[0], 1), fill_value=1.0)
147
+ val_x1_weights = torch.full((val_x1.shape[0], 1), fill_value=1.0)
148
+
149
+ # Adjust split_index to ensure minimum validation samples
150
+ if self.n_samples - split_index < self.batch_size:
151
+ split_index = self.n_samples - self.batch_size
152
+
153
+ self.train_dataloaders = {
154
+ "x0": DataLoader(TensorDataset(train_x0, train_x0_weights), batch_size=self.batch_size, shuffle=True, drop_last=True),
155
+ "x1": DataLoader(TensorDataset(train_x1, train_x1_weights), batch_size=self.batch_size, shuffle=True, drop_last=True),
156
+ }
157
+
158
+ self.val_dataloaders = {
159
+ "x0": DataLoader(TensorDataset(val_x0, val_x0_weights), batch_size=self.batch_size, shuffle=True, drop_last=True),
160
+ "x1": DataLoader(TensorDataset(val_x1, val_x1_weights), batch_size=self.batch_size, shuffle=True, drop_last=True),
161
+ }
162
+
163
+ # to edit?
164
+ self.test_dataloaders = {
165
+ "x0": DataLoader(TensorDataset(val_x0, val_x0_weights), batch_size=self.val_x0.shape[0], shuffle=True, drop_last=False),
166
+ "x1": DataLoader(TensorDataset(val_x1, val_x1_weights), batch_size=self.val_x0.shape[0], shuffle=True, drop_last=True),
167
+ "dataset": DataLoader(TensorDataset(self.dataset), batch_size=self.dataset.shape[0], shuffle=True, drop_last=False),
168
+ }
169
+
170
+ points = self.dataset.cpu().numpy()
171
+ x, y = points[:, 0], points[:, 1]
172
+ # Diagonal-based coordinates (rotated 45°)
173
+ u = (x + y) / np.sqrt(2) # along x=y
174
+ # start region (A) using u
175
+ u_thresh = np.percentile(u, 30) # tweak this threshold to control size
176
+ mask_A = u <= u_thresh
177
+
178
+ # among the rest, split by x=y diagonal
179
+ remaining = ~mask_A
180
+ mask_B = remaining & (x < y) # left of diagonal
181
+ mask_C = remaining & (x >= y) # right of diagonal
182
+
183
+ # Assign dataloaders
184
+ self.metric_samples_dataloaders = [
185
+ DataLoader(torch.tensor(points[mask_A], dtype=torch.float32), batch_size=points[mask_A].shape[0], shuffle=False),
186
+ DataLoader(torch.tensor(points[remaining], dtype=torch.float32), batch_size=points[remaining].shape[0], shuffle=False),
187
+ ]
188
+
189
+ def train_dataloader(self):
190
+ combined_loaders = {
191
+ "train_samples": CombinedLoader(self.train_dataloaders, mode="min_size"),
192
+ "metric_samples": CombinedLoader(
193
+ self.metric_samples_dataloaders, mode="min_size"
194
+ ),
195
+ }
196
+ return CombinedLoader(combined_loaders, mode="max_size_cycle")
197
+
198
+ def val_dataloader(self):
199
+ combined_loaders = {
200
+ "val_samples": CombinedLoader(self.val_dataloaders, mode="min_size"),
201
+ "metric_samples": CombinedLoader(
202
+ self.metric_samples_dataloaders, mode="min_size"
203
+ ),
204
+ }
205
+
206
+ return CombinedLoader(combined_loaders, mode="max_size_cycle")
207
+
208
+ def test_dataloader(self):
209
+ combined_loaders = {
210
+ "test_samples": CombinedLoader(self.test_dataloaders, mode="min_size"),
211
+ "metric_samples": CombinedLoader(
212
+ self.metric_samples_dataloaders, mode="min_size"
213
+ ),
214
+ }
215
+
216
+ return CombinedLoader(combined_loaders, mode="max_size_cycle")
217
+
218
+ def get_tangent_proj(self, points):
219
+ w = self.get_tangent_plane(points)
220
+ return partial(LidarSingleDataModule.projection_op, w=w)
221
+
222
+ def get_tangent_plane(self, points, temp=1e-3):
223
+ points_np = points.detach().cpu().numpy()
224
+ _, idx = self.tree.query(points_np, k=self.k)
225
+ nearest_pts = self.dataset[idx]
226
+ nearest_pts = torch.tensor(nearest_pts).to(points)
227
+
228
+ dists = (points.unsqueeze(1) - nearest_pts).pow(2).sum(-1, keepdim=True)
229
+ weights = torch.exp(-dists / temp)
230
+
231
+ # Fits plane with least vertical distance.
232
+ w = LidarSingleDataModule.fit_plane(nearest_pts, weights)
233
+ return w
234
+
235
+ @staticmethod
236
+ def fit_plane(points, weights=None):
237
+ """Expects points to be of shape (..., 3).
238
+ Returns [a, b, c] such that the plane is defined as
239
+ ax + by + c = z
240
+ """
241
+ D = torch.cat([points[..., :2], torch.ones_like(points[..., 2:3])], dim=-1)
242
+ z = points[..., 2]
243
+ if weights is not None:
244
+ Dtrans = D.transpose(-1, -2)
245
+ else:
246
+ DW = D * weights
247
+ Dtrans = DW.transpose(-1, -2)
248
+ w = torch.linalg.solve(
249
+ torch.matmul(Dtrans, D), torch.matmul(Dtrans, z.unsqueeze(-1))
250
+ ).squeeze(-1)
251
+ return w
252
+
253
+ @staticmethod
254
+ def projection_op(x, w):
255
+ """Projects points to a plane defined by w."""
256
+ # Normal vector to the tangent plane.
257
+ n = torch.cat([w[..., :2], -torch.ones_like(w[..., 2:3])], dim=1)
258
+
259
+ pn = torch.sum(x * n, dim=-1, keepdim=True)
260
+ nn = torch.sum(n * n, dim=-1, keepdim=True)
261
+
262
+ # Offset.
263
+ d = w[..., 2:3]
264
+
265
+ # Projection of x onto n.
266
+ projn_x = ((pn + d) / nn) * n
267
+
268
+ # Remove component in the normal direction.
269
+ return x - projn_x
270
+
271
+ def get_timepoint_data(self):
272
+ """Return data organized by timepoints for visualization"""
273
+ return {
274
+ 't0': self.coords_t0,
275
+ 't1': self.coords_t1,
276
+ 'time_labels': self.time_labels
277
+ }
278
+
279
+ def get_datamodule():
280
+ datamodule = LidarSingleDataModule(args)
281
+ datamodule.setup(stage="fit")
282
+ return datamodule
dataloaders/mouse_data.py ADDED
@@ -0,0 +1,438 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import sys
3
+ sys.argv = ['']
4
+ from sklearn.preprocessing import StandardScaler
5
+ import pytorch_lightning as pl
6
+ from torch.utils.data import DataLoader
7
+ from lightning.pytorch.utilities.combined_loader import CombinedLoader
8
+ import numpy as np
9
+ from scipy.spatial import cKDTree
10
+ import math
11
+ from functools import partial
12
+ from sklearn.cluster import KMeans, DBSCAN
13
+ import matplotlib.pyplot as plt
14
+ import pandas as pd
15
+ from torch.utils.data import TensorDataset
16
+
17
+ from train.parsers_sc import parse_args
18
+ args = parse_args()
19
+
20
+ class WeightedBranchedCellDataModule(pl.LightningDataModule):
21
+ def __init__(self, args):
22
+ super().__init__()
23
+ self.save_hyperparameters()
24
+
25
+ self.data_path = "./data/mouse_hematopoiesis.csv"
26
+ self.batch_size = args.batch_size
27
+ self.max_dim = args.dim
28
+ self.whiten = args.whiten
29
+ self.k = 20
30
+ self.n_samples = 1429
31
+ self.num_timesteps = 3 # t=0, t=1, t=2
32
+ self.split_ratios = args.split_ratios
33
+ self.metric_clusters = args.metric_clusters
34
+ self.args = args
35
+ self._prepare_data()
36
+
37
+
38
+ def _prepare_data(self):
39
+ print("Preparing cell data in BranchedCellDataModule")
40
+
41
+ df = pd.read_csv(self.data_path)
42
+
43
+ # Build dictionary of coordinates by time
44
+ coords_by_t = {
45
+ t: df[df["samples"] == t][["x1","x2"]].values
46
+ for t in sorted(df["samples"].unique())
47
+ }
48
+ n0 = coords_by_t[0].shape[0] # Number of T=0 points
49
+ self.n_samples = n0 # Update n_samples to match actual data if changes
50
+
51
+ # Cluster the t=2 cells into two branches
52
+ km = KMeans(n_clusters=2, random_state=42).fit(coords_by_t[2])
53
+ df2 = df[df["samples"] == 2].copy()
54
+ df2["branch"] = km.labels_
55
+
56
+ cluster_counts = df2["branch"].value_counts().sort_index()
57
+ print(cluster_counts)
58
+
59
+ # Sample n0 points from each branch
60
+ endpoints = {}
61
+ for b in (0, 1):
62
+ endpoints[b] = (
63
+ df2[df2["branch"] == b]
64
+ .sample(n=n0, random_state=42)[["x1","x2"]]
65
+ .values
66
+ )
67
+
68
+ x0 = torch.tensor(coords_by_t[0], dtype=torch.float32) # T=0 coordinates index
69
+ x_inter = torch.tensor(coords_by_t[1], dtype=torch.float32)
70
+ x1_1 = torch.tensor(endpoints[0], dtype=torch.float32) # Branch index
71
+ x1_2 = torch.tensor(endpoints[1], dtype=torch.float32) # Branch index
72
+
73
+ self.coords_t0 = x0
74
+ self.coords_t1 = x_inter
75
+ self.coords_t2_1 = x1_1
76
+ self.coords_t2_2 = x1_2
77
+ self.time_labels = np.concatenate([
78
+ np.zeros(len(self.coords_t0)), # t=0
79
+ np.ones(len(self.coords_t1)), # t=1
80
+ np.ones(len(self.coords_t2_1)) * 2, # t=1
81
+ np.ones(len(self.coords_t2_2)) * 2,
82
+ ])
83
+
84
+ split_index = int(n0 * self.split_ratios[0])
85
+
86
+ if n0 - split_index < self.batch_size:
87
+ split_index = n0 - self.batch_size
88
+
89
+ train_x0 = x0[:split_index]
90
+ val_x0 = x0[split_index:]
91
+ train_x1_1 = x1_1[:split_index]
92
+ val_x1_1 = x1_1[split_index:]
93
+ train_x1_2 = x1_2[:split_index]
94
+ val_x1_2 = x1_2[split_index:]
95
+
96
+ self.val_x0 = val_x0
97
+
98
+ train_x0_weights = torch.full((train_x0.shape[0], 1), fill_value=1.0)
99
+ train_x1_1_weights = torch.full((train_x1_1.shape[0], 1), fill_value=0.5)
100
+ train_x1_2_weights = torch.full((train_x1_2.shape[0], 1), fill_value=0.5)
101
+
102
+ val_x0_weights = torch.full((val_x0.shape[0], 1), fill_value=1.0)
103
+ val_x1_1_weights = torch.full((val_x1_1.shape[0], 1), fill_value=0.5)
104
+ val_x1_2_weights = torch.full((val_x1_2.shape[0], 1), fill_value=0.5)
105
+
106
+ if self.n_samples - split_index < self.batch_size:
107
+ split_index = self.n_samples - self.batch_size
108
+
109
+ self.train_dataloaders = {
110
+ "x0": DataLoader(TensorDataset(train_x0, train_x0_weights), batch_size=self.batch_size, shuffle=True, drop_last=True),
111
+ "x1_1": DataLoader(TensorDataset(train_x1_1, train_x1_1_weights), batch_size=self.batch_size, shuffle=True, drop_last=True),
112
+ "x1_2": DataLoader(TensorDataset(train_x1_2, train_x1_2_weights), batch_size=self.batch_size, shuffle=True, drop_last=True),
113
+ }
114
+
115
+ self.val_dataloaders = {
116
+ "x0": DataLoader(TensorDataset(val_x0, val_x0_weights), batch_size=self.batch_size, shuffle=False, drop_last=True),
117
+ "x1_1": DataLoader(TensorDataset(val_x1_1, val_x1_1_weights), batch_size=self.batch_size, shuffle=True, drop_last=True),
118
+ "x1_2": DataLoader(TensorDataset(val_x1_2, val_x1_2_weights), batch_size=self.batch_size, shuffle=True, drop_last=True),
119
+ }
120
+
121
+ all_data = np.vstack([coords_by_t[t] for t in sorted(coords_by_t.keys())])
122
+ self.dataset = torch.tensor(all_data, dtype=torch.float32)
123
+ self.tree = cKDTree(all_data)
124
+
125
+ # if whitening is enabled, need to apply this to the full dataset
126
+ #if self.whiten:
127
+ #self.scaler = StandardScaler()
128
+ #self.dataset = torch.tensor(
129
+ #self.scaler.fit_transform(all_data), dtype=torch.float32
130
+ #)
131
+
132
+ self.test_dataloaders = {
133
+ "x0": DataLoader(TensorDataset(val_x0, val_x0_weights), batch_size=self.val_x0.shape[0], shuffle=False, drop_last=False),
134
+ "dataset": DataLoader(TensorDataset(self.dataset), batch_size=self.dataset.shape[0], shuffle=False, drop_last=False),
135
+ }
136
+
137
+ # Metric Dataloader
138
+ # K-means clustering of ALL points into 2 groups
139
+ if self.metric_clusters == 3:
140
+ km_all = KMeans(n_clusters=3, random_state=45).fit(self.dataset.numpy())
141
+ cluster_labels = km_all.labels_
142
+
143
+ cluster_0_mask = cluster_labels == 0
144
+ cluster_1_mask = cluster_labels == 1
145
+ cluster_2_mask = cluster_labels == 2
146
+
147
+ samples = self.dataset.cpu().numpy()
148
+
149
+ cluster_0_data = samples[cluster_0_mask]
150
+ cluster_1_data = samples[cluster_1_mask]
151
+ cluster_2_data = samples[cluster_2_mask]
152
+
153
+ self.metric_samples_dataloaders = [
154
+ DataLoader(
155
+ torch.tensor(cluster_1_data, dtype=torch.float32),
156
+ batch_size=cluster_1_data.shape[0],
157
+ shuffle=False,
158
+ drop_last=False,
159
+ ),
160
+ DataLoader(
161
+ torch.tensor(cluster_2_data, dtype=torch.float32),
162
+ batch_size=cluster_2_data.shape[0],
163
+ shuffle=False,
164
+ drop_last=False,
165
+ ),
166
+
167
+ DataLoader(
168
+ torch.tensor(cluster_0_data, dtype=torch.float32),
169
+ batch_size=cluster_0_data.shape[0],
170
+ shuffle=False,
171
+ drop_last=False,
172
+ ),
173
+ ]
174
+ else:
175
+ km_all = KMeans(n_clusters=2, random_state=45).fit(self.dataset.numpy())
176
+ cluster_labels = km_all.labels_
177
+
178
+ cluster_0_mask = cluster_labels == 0
179
+ cluster_1_mask = cluster_labels == 1
180
+
181
+ samples = self.dataset.cpu().numpy()
182
+
183
+ cluster_0_data = samples[cluster_0_mask]
184
+ cluster_1_data = samples[cluster_1_mask]
185
+
186
+ self.metric_samples_dataloaders = [
187
+ DataLoader(
188
+ torch.tensor(cluster_1_data, dtype=torch.float32),
189
+ batch_size=cluster_1_data.shape[0],
190
+ shuffle=False,
191
+ drop_last=False,
192
+ ),
193
+ DataLoader(
194
+ torch.tensor(cluster_0_data, dtype=torch.float32),
195
+ batch_size=cluster_0_data.shape[0],
196
+ shuffle=False,
197
+ drop_last=False,
198
+ ),
199
+ ]
200
+
201
+
202
+ def train_dataloader(self):
203
+ combined_loaders = {
204
+ "train_samples": CombinedLoader(self.train_dataloaders, mode="min_size"),
205
+ "metric_samples": CombinedLoader(
206
+ self.metric_samples_dataloaders, mode="min_size"
207
+ ),
208
+ }
209
+ return CombinedLoader(combined_loaders, mode="max_size_cycle")
210
+
211
+ def val_dataloader(self):
212
+ combined_loaders = {
213
+ "val_samples": CombinedLoader(self.val_dataloaders, mode="min_size"),
214
+ "metric_samples": CombinedLoader(
215
+ self.metric_samples_dataloaders, mode="min_size"
216
+ ),
217
+ }
218
+
219
+ return CombinedLoader(combined_loaders, mode="max_size_cycle")
220
+
221
+ def test_dataloader(self):
222
+ combined_loaders = {
223
+ "test_samples": CombinedLoader(self.test_dataloaders, mode="min_size"),
224
+ "metric_samples": CombinedLoader(
225
+ self.metric_samples_dataloaders, mode="min_size"
226
+ ),
227
+ }
228
+
229
+ return CombinedLoader(combined_loaders, mode="max_size_cycle")
230
+
231
+ def get_manifold_proj(self, points):
232
+ """Adapted for 2D cell data - uses local neighborhood averaging instead of plane fitting"""
233
+ return partial(self.local_smoothing_op, tree=self.tree, dataset=self.dataset)
234
+
235
+ @staticmethod
236
+ def local_smoothing_op(x, tree, dataset, k=10, temp=1e-3):
237
+ """
238
+ Apply local smoothing based on k-nearest neighbors in the full dataset
239
+ This replaces the plane projection for 2D manifold regularization
240
+ """
241
+ points_np = x.detach().cpu().numpy()
242
+ _, idx = tree.query(points_np, k=k)
243
+ nearest_pts = dataset[idx] # Shape: (batch_size, k, 2)
244
+
245
+ # Compute weighted average of neighbors
246
+ dists = (x.unsqueeze(1) - nearest_pts).pow(2).sum(-1, keepdim=True)
247
+ weights = torch.exp(-dists / temp)
248
+ weights = weights / weights.sum(dim=1, keepdim=True)
249
+
250
+ # Weighted average of neighbors
251
+ smoothed = (weights * nearest_pts).sum(dim=1)
252
+
253
+ # Blend original point with smoothed version
254
+ alpha = 0.3 # How much smoothing to apply
255
+ return (1 - alpha) * x + alpha * smoothed
256
+
257
+ def get_timepoint_data(self):
258
+ """Return data organized by timepoints for visualization"""
259
+ return {
260
+ 't0': self.coords_t0,
261
+ 't1': self.coords_t1,
262
+ 't2_1': self.coords_t2_1,
263
+ 't2_2': self.coords_t2_2,
264
+ 'time_labels': self.time_labels
265
+ }
266
+
267
+
268
+
269
+ class SingleBranchCellDataModule(pl.LightningDataModule):
270
+ def __init__(self, args):
271
+ super().__init__()
272
+ self.save_hyperparameters()
273
+
274
+ self.data_path = "./data/mouse_hematopoiesis.csv"
275
+ self.batch_size = args.batch_size
276
+ self.max_dim = args.dim
277
+ self.whiten = args.whiten
278
+ self.k = 20
279
+ self.n_samples = 1429
280
+ self.num_timesteps = 3 # t=0, t=1, t=2
281
+ self.split_ratios = args.split_ratios
282
+ self.metric_clusters = 3
283
+ self.args = args
284
+ self._prepare_data()
285
+
286
+
287
+ def _prepare_data(self):
288
+ print("Preparing cell data in BranchedCellDataModule")
289
+
290
+ df = pd.read_csv(self.data_path)
291
+
292
+ # Build dictionary of coordinates by time
293
+ coords_by_t = {
294
+ t: df[df["samples"] == t][["x1","x2"]].values
295
+ for t in sorted(df["samples"].unique())
296
+ }
297
+ n0 = coords_by_t[0].shape[0] # Number of T=0 points
298
+ self.n_samples = n0 # Update n_samples to match actual data if changes
299
+
300
+ x0 = torch.tensor(coords_by_t[0], dtype=torch.float32) # T=0 coordinates index
301
+ x_inter = torch.tensor(coords_by_t[1], dtype=torch.float32)
302
+ x1 = torch.tensor(coords_by_t[2], dtype=torch.float32) # Branch index
303
+
304
+ split_index = int(n0 * self.split_ratios[0])
305
+
306
+ if n0 - split_index < self.batch_size:
307
+ split_index = n0 - self.batch_size
308
+
309
+ train_x0 = x0[:split_index]
310
+ val_x0 = x0[split_index:]
311
+ train_x1 = x1[:split_index]
312
+ val_x1 = x1[split_index:]
313
+
314
+ self.val_x0 = val_x0
315
+
316
+ train_x0_weights = torch.full((train_x0.shape[0], 1), fill_value=1.0)
317
+ train_x1_weights = torch.full((train_x1.shape[0], 1), fill_value=0.5)
318
+
319
+ val_x0_weights = torch.full((val_x0.shape[0], 1), fill_value=1.0)
320
+ val_x1_weights = torch.full((val_x1.shape[0], 1), fill_value=0.5)
321
+
322
+ if self.n_samples - split_index < self.batch_size:
323
+ split_index = self.n_samples - self.batch_size
324
+
325
+ self.train_dataloaders = {
326
+ "x0": DataLoader(TensorDataset(train_x0, train_x0_weights), batch_size=self.batch_size, shuffle=True, drop_last=True),
327
+ "x1": DataLoader(TensorDataset(train_x1, train_x1_weights), batch_size=self.batch_size, shuffle=True, drop_last=True),
328
+ }
329
+
330
+ self.val_dataloaders = {
331
+ "x0": DataLoader(TensorDataset(val_x0, val_x0_weights), batch_size=self.batch_size, shuffle=False, drop_last=True),
332
+ "x1": DataLoader(TensorDataset(val_x1, val_x1_weights), batch_size=self.batch_size, shuffle=True, drop_last=True),
333
+ }
334
+
335
+ all_data = np.vstack([coords_by_t[t] for t in sorted(coords_by_t.keys())])
336
+ self.dataset = torch.tensor(all_data, dtype=torch.float32)
337
+ self.tree = cKDTree(all_data)
338
+
339
+ # if whitening is enabled, need to apply this to the full dataset
340
+ if self.whiten:
341
+ self.scaler = StandardScaler()
342
+ self.dataset = torch.tensor(
343
+ self.scaler.fit_transform(all_data), dtype=torch.float32
344
+ )
345
+
346
+ self.test_dataloaders = {
347
+ "x0": DataLoader(TensorDataset(val_x0, val_x0_weights), batch_size=self.val_x0.shape[0], shuffle=False, drop_last=False),
348
+ "dataset": DataLoader(TensorDataset(self.dataset), batch_size=self.dataset.shape[0], shuffle=False, drop_last=False),
349
+ }
350
+
351
+ # Metric Dataloader
352
+ # K-means clustering of ALL points into 2 groups
353
+ km_all = KMeans(n_clusters=2, random_state=45).fit(self.dataset.numpy())
354
+ cluster_labels = km_all.labels_
355
+
356
+ cluster_0_mask = cluster_labels == 0
357
+ cluster_1_mask = cluster_labels == 1
358
+
359
+ samples = self.dataset.cpu().numpy()
360
+
361
+ cluster_0_data = samples[cluster_0_mask]
362
+ cluster_1_data = samples[cluster_1_mask]
363
+
364
+ self.metric_samples_dataloaders = [
365
+ DataLoader(
366
+ torch.tensor(cluster_1_data, dtype=torch.float32),
367
+ batch_size=cluster_1_data.shape[0],
368
+ shuffle=False,
369
+ drop_last=False,
370
+ ),
371
+ DataLoader(
372
+ torch.tensor(cluster_0_data, dtype=torch.float32),
373
+ batch_size=cluster_0_data.shape[0],
374
+ shuffle=False,
375
+ drop_last=False,
376
+ ),
377
+ ]
378
+
379
+
380
+ def train_dataloader(self):
381
+ combined_loaders = {
382
+ "train_samples": CombinedLoader(self.train_dataloaders, mode="min_size"),
383
+ "metric_samples": CombinedLoader(
384
+ self.metric_samples_dataloaders, mode="min_size"
385
+ ),
386
+ }
387
+ return CombinedLoader(combined_loaders, mode="max_size_cycle")
388
+
389
+ def val_dataloader(self):
390
+ combined_loaders = {
391
+ "val_samples": CombinedLoader(self.val_dataloaders, mode="min_size"),
392
+ "metric_samples": CombinedLoader(
393
+ self.metric_samples_dataloaders, mode="min_size"
394
+ ),
395
+ }
396
+
397
+ return CombinedLoader(combined_loaders, mode="max_size_cycle")
398
+
399
+ def test_dataloader(self):
400
+ combined_loaders = {
401
+ "test_samples": CombinedLoader(self.test_dataloaders, mode="min_size"),
402
+ "metric_samples": CombinedLoader(
403
+ self.metric_samples_dataloaders, mode="min_size"
404
+ ),
405
+ }
406
+
407
+ return CombinedLoader(combined_loaders, mode="max_size_cycle")
408
+
409
+ def get_manifold_proj(self, points):
410
+ """Adapted for 2D cell data - uses local neighborhood averaging instead of plane fitting"""
411
+ return partial(self.local_smoothing_op, tree=self.tree, dataset=self.dataset)
412
+
413
+ @staticmethod
414
+ def local_smoothing_op(x, tree, dataset, k=10, temp=1e-3):
415
+ """
416
+ Apply local smoothing based on k-nearest neighbors in the full dataset
417
+ This replaces the plane projection for 2D manifold regularization
418
+ """
419
+ points_np = x.detach().cpu().numpy()
420
+ _, idx = tree.query(points_np, k=k)
421
+ nearest_pts = dataset[idx] # Shape: (batch_size, k, 2)
422
+
423
+ # Compute weighted average of neighbors
424
+ dists = (x.unsqueeze(1) - nearest_pts).pow(2).sum(-1, keepdim=True)
425
+ weights = torch.exp(-dists / temp)
426
+ weights = weights / weights.sum(dim=1, keepdim=True)
427
+
428
+ # Weighted average of neighbors
429
+ smoothed = (weights * nearest_pts).sum(dim=1)
430
+
431
+ # Blend original point with smoothed version
432
+ alpha = 0.3 # How much smoothing to apply
433
+ return (1 - alpha) * x + alpha * smoothed
434
+
435
+ def get_datamodule():
436
+ datamodule = WeightedBranchedCellDataModule(args)
437
+ datamodule.setup(stage="fit")
438
+ return datamodule
dataloaders/three_branch_data.py ADDED
@@ -0,0 +1,310 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import sys
3
+ sys.argv = ['']
4
+ from sklearn.preprocessing import StandardScaler
5
+ import pytorch_lightning as pl
6
+ from torch.utils.data import DataLoader
7
+ from lightning.pytorch.utilities.combined_loader import CombinedLoader
8
+ import pandas as pd
9
+ import numpy as np
10
+ from functools import partial
11
+ from scipy.spatial import cKDTree
12
+ from sklearn.cluster import KMeans
13
+ from torch.utils.data import TensorDataset
14
+
15
+ from train.parsers_tahoe import parse_args
16
+ args = parse_args()
17
+
18
+ class ThreeBranchTahoeDataModule(pl.LightningDataModule):
19
+ def __init__(self, args):
20
+ super().__init__()
21
+ self.save_hyperparameters()
22
+
23
+ self.batch_size = args.batch_size
24
+ self.max_dim = args.dim
25
+ self.whiten = args.whiten
26
+ self.split_ratios = args.split_ratios
27
+ self.num_timesteps = 2
28
+ self.data_path = "./data/Trametinib_5.0uM_pca_and_leidenumap_labels.csv"
29
+ self.args = args
30
+
31
+ self._prepare_data()
32
+
33
+ def _prepare_data(self):
34
+ df = pd.read_csv(self.data_path, comment='#')
35
+ df = df.iloc[:, 1:]
36
+ df = df.replace('', np.nan)
37
+ pc_cols = df.columns[:50]
38
+ for col in pc_cols:
39
+ df[col] = pd.to_numeric(df[col], errors='coerce')
40
+ leiden_dmso_col = 'leiden_DMSO_TF_0.0uM'
41
+ leiden_clonidine_col = 'leiden_Trametinib_5.0uM'
42
+
43
+ dmso_mask = df[leiden_dmso_col].notna() # Has leiden value in DMSO column
44
+ clonidine_mask = df[leiden_clonidine_col].notna() # Has leiden value in Clonidine column
45
+
46
+ dmso_data = df[dmso_mask].copy()
47
+ clonidine_data = df[clonidine_mask].copy()
48
+
49
+ # Updated to include all three clusters: 0, 4, and 6
50
+ top_clonidine_clusters = ['1.0', '3.0', '5.0']
51
+
52
+ x1_1_data = clonidine_data[clonidine_data[leiden_clonidine_col].astype(str) == top_clonidine_clusters[0]]
53
+ x1_2_data = clonidine_data[clonidine_data[leiden_clonidine_col].astype(str) == top_clonidine_clusters[1]]
54
+ x1_3_data = clonidine_data[clonidine_data[leiden_clonidine_col].astype(str) == top_clonidine_clusters[2]]
55
+
56
+ x1_1_coords = x1_1_data[pc_cols].values
57
+ x1_2_coords = x1_2_data[pc_cols].values
58
+ x1_3_coords = x1_3_data[pc_cols].values
59
+
60
+ x1_1_coords = x1_1_coords.astype(float)
61
+ x1_2_coords = x1_2_coords.astype(float)
62
+ x1_3_coords = x1_3_coords.astype(float)
63
+
64
+ # Target size is now the minimum across all three endpoint clusters
65
+ target_size = min(len(x1_1_coords), len(x1_2_coords), len(x1_3_coords))
66
+
67
+ # Helper function to select points closest to centroid
68
+ def select_closest_to_centroid(coords, target_size):
69
+ if len(coords) <= target_size:
70
+ return coords
71
+
72
+ # Calculate centroid
73
+ centroid = np.mean(coords, axis=0)
74
+
75
+ # Calculate distances to centroid
76
+ distances = np.linalg.norm(coords - centroid, axis=1)
77
+
78
+ # Get indices of closest points
79
+ closest_indices = np.argsort(distances)[:target_size]
80
+
81
+ return coords[closest_indices]
82
+
83
+ # Sample all endpoint clusters to target size using centroid-based selection
84
+ x1_1_coords = select_closest_to_centroid(x1_1_coords, target_size)
85
+ x1_2_coords = select_closest_to_centroid(x1_2_coords, target_size)
86
+ x1_3_coords = select_closest_to_centroid(x1_3_coords, target_size)
87
+
88
+ dmso_cluster_counts = dmso_data[leiden_dmso_col].value_counts()
89
+
90
+ # DMSO (unchanged)
91
+ largest_dmso_cluster = dmso_cluster_counts.index[0]
92
+ dmso_cluster_data = dmso_data[dmso_data[leiden_dmso_col] == largest_dmso_cluster]
93
+
94
+ dmso_coords = dmso_cluster_data[pc_cols].values
95
+
96
+ # Random sampling from largest DMSO cluster to match target size
97
+ # For DMSO, we'll also use centroid-based selection for consistency
98
+ if len(dmso_coords) >= target_size:
99
+ x0_coords = select_closest_to_centroid(dmso_coords, target_size)
100
+ else:
101
+ # If largest cluster is smaller than target, use all of it and pad with other DMSO cells
102
+ remaining_needed = target_size - len(dmso_coords)
103
+ other_dmso_data = dmso_data[dmso_data[leiden_dmso_col] != largest_dmso_cluster]
104
+ other_dmso_coords = other_dmso_data[pc_cols].values
105
+
106
+ if len(other_dmso_coords) >= remaining_needed:
107
+ # Select closest to centroid from other DMSO cells
108
+ other_selected = select_closest_to_centroid(other_dmso_coords, remaining_needed)
109
+ x0_coords = np.vstack([dmso_coords, other_selected])
110
+ else:
111
+ # Use all available DMSO cells and reduce target size
112
+ all_dmso_coords = dmso_data[pc_cols].values
113
+ target_size = min(target_size, len(all_dmso_coords))
114
+ x0_coords = select_closest_to_centroid(all_dmso_coords, target_size)
115
+
116
+ # Re-select endpoint clusters with updated target size
117
+ x1_1_coords = select_closest_to_centroid(x1_1_data[pc_cols].values.astype(float), target_size)
118
+ x1_2_coords = select_closest_to_centroid(x1_2_data[pc_cols].values.astype(float), target_size)
119
+ x1_3_coords = select_closest_to_centroid(x1_3_data[pc_cols].values.astype(float), target_size)
120
+
121
+ # No need to resample since we already selected the right number
122
+ # The endpoint clusters are already at target_size from centroid-based selection
123
+
124
+ self.n_samples = target_size
125
+
126
+ # for plotting
127
+ self.coords_t0 = torch.tensor(x0_coords, dtype=torch.float32)
128
+ self.coords_t1_1 = torch.tensor(x1_1_coords, dtype=torch.float32)
129
+ self.coords_t1_2 = torch.tensor(x1_2_coords, dtype=torch.float32)
130
+ self.coords_t1_3 = torch.tensor(x1_3_coords, dtype=torch.float32)
131
+
132
+ self.time_labels = np.concatenate([
133
+ np.zeros(len(self.coords_t0)), # t=0
134
+ np.ones(len(self.coords_t1_1)), # t=1
135
+ np.ones(len(self.coords_t1_2)), # t=1
136
+ np.ones(len(self.coords_t1_3)), # t=1
137
+ ])
138
+
139
+ x0 = torch.tensor(x0_coords, dtype=torch.float32)
140
+ x1_1 = torch.tensor(x1_1_coords, dtype=torch.float32)
141
+ x1_2 = torch.tensor(x1_2_coords, dtype=torch.float32)
142
+ x1_3 = torch.tensor(x1_3_coords, dtype=torch.float32)
143
+
144
+ split_index = int(target_size * self.split_ratios[0])
145
+
146
+ if target_size - split_index < self.batch_size:
147
+ split_index = target_size - self.batch_size
148
+
149
+ train_x0 = x0[:split_index]
150
+ val_x0 = x0[split_index:]
151
+ train_x1_1 = x1_1[:split_index]
152
+ val_x1_1 = x1_1[split_index:]
153
+ train_x1_2 = x1_2[:split_index]
154
+ val_x1_2 = x1_2[split_index:]
155
+ train_x1_3 = x1_3[:split_index]
156
+ val_x1_3 = x1_3[split_index:]
157
+
158
+ self.val_x0 = val_x0
159
+
160
+ train_x0_weights = torch.full((train_x0.shape[0], 1), fill_value=1.0)
161
+ train_x1_1_weights = torch.full((train_x1_1.shape[0], 1), fill_value=0.603)
162
+ train_x1_2_weights = torch.full((train_x1_2.shape[0], 1), fill_value=0.255)
163
+ train_x1_3_weights = torch.full((train_x1_3.shape[0], 1), fill_value=0.142)
164
+
165
+ val_x0_weights = torch.full((val_x0.shape[0], 1), fill_value=1.0)
166
+ val_x1_1_weights = torch.full((val_x1_1.shape[0], 1), fill_value=0.603)
167
+ val_x1_2_weights = torch.full((val_x1_2.shape[0], 1), fill_value=0.255)
168
+ val_x1_3_weights = torch.full((val_x1_3.shape[0], 1), fill_value=0.142)
169
+
170
+ # Updated train dataloaders to include x1_3
171
+ self.train_dataloaders = {
172
+ "x0": DataLoader(TensorDataset(train_x0, train_x0_weights), batch_size=self.batch_size, shuffle=True, drop_last=True),
173
+ "x1_1": DataLoader(TensorDataset(train_x1_1, train_x1_1_weights), batch_size=self.batch_size, shuffle=True, drop_last=True),
174
+ "x1_2": DataLoader(TensorDataset(train_x1_2, train_x1_2_weights), batch_size=self.batch_size, shuffle=True, drop_last=True),
175
+ "x1_3": DataLoader(TensorDataset(train_x1_3, train_x1_3_weights), batch_size=self.batch_size, shuffle=True, drop_last=True),
176
+ }
177
+
178
+ # Updated val dataloaders to include x1_3
179
+ self.val_dataloaders = {
180
+ "x0": DataLoader(TensorDataset(val_x0, val_x0_weights), batch_size=self.batch_size, shuffle=False, drop_last=True),
181
+ "x1_1": DataLoader(TensorDataset(val_x1_1, val_x1_1_weights), batch_size=self.batch_size, shuffle=True, drop_last=True),
182
+ "x1_2": DataLoader(TensorDataset(val_x1_2, val_x1_2_weights), batch_size=self.batch_size, shuffle=True, drop_last=True),
183
+ "x1_3": DataLoader(TensorDataset(val_x1_3, val_x1_3_weights), batch_size=self.batch_size, shuffle=True, drop_last=True),
184
+ }
185
+
186
+ all_coords = df[pc_cols].dropna().values.astype(float)
187
+ self.dataset = torch.tensor(all_coords, dtype=torch.float32)
188
+ self.tree = cKDTree(all_coords)
189
+
190
+ self.test_dataloaders = {
191
+ "x0": DataLoader(TensorDataset(val_x0, val_x0_weights), batch_size=self.val_x0.shape[0], shuffle=False, drop_last=False),
192
+ "dataset": DataLoader(TensorDataset(self.dataset), batch_size=self.dataset.shape[0], shuffle=False, drop_last=False),
193
+ }
194
+
195
+ # Updated metric samples - now using 4 clusters instead of 3
196
+ #km_all = KMeans(n_clusters=4, random_state=42).fit(self.dataset.numpy())
197
+ km_all = KMeans(n_clusters=4, random_state=0).fit(self.dataset[:, :3].numpy())
198
+
199
+ cluster_labels = km_all.labels_
200
+
201
+ cluster_0_mask = cluster_labels == 0
202
+ cluster_1_mask = cluster_labels == 1
203
+ cluster_2_mask = cluster_labels == 2
204
+ cluster_3_mask = cluster_labels == 3
205
+
206
+ samples = self.dataset.cpu().numpy()
207
+
208
+ cluster_0_data = samples[cluster_0_mask]
209
+ cluster_1_data = samples[cluster_1_mask]
210
+ cluster_2_data = samples[cluster_2_mask]
211
+ cluster_3_data = samples[cluster_3_mask]
212
+
213
+ self.metric_samples_dataloaders = [
214
+ DataLoader(
215
+ torch.tensor(cluster_1_data, dtype=torch.float32),
216
+ batch_size=cluster_1_data.shape[0],
217
+ shuffle=False,
218
+ drop_last=False,
219
+ ),
220
+ DataLoader(
221
+ torch.tensor(cluster_3_data, dtype=torch.float32),
222
+ batch_size=cluster_3_data.shape[0],
223
+ shuffle=False,
224
+ drop_last=False,
225
+ ),
226
+
227
+
228
+ DataLoader(
229
+ torch.tensor(cluster_2_data, dtype=torch.float32),
230
+ batch_size=cluster_2_data.shape[0],
231
+ shuffle=False,
232
+ drop_last=False,
233
+ ),
234
+ DataLoader(
235
+ torch.tensor(cluster_0_data, dtype=torch.float32),
236
+ batch_size=cluster_0_data.shape[0],
237
+ shuffle=False,
238
+ drop_last=False,
239
+ ),
240
+ ]
241
+
242
+ def train_dataloader(self):
243
+ combined_loaders = {
244
+ "train_samples": CombinedLoader(self.train_dataloaders, mode="min_size"),
245
+ "metric_samples": CombinedLoader(
246
+ self.metric_samples_dataloaders, mode="min_size"
247
+ ),
248
+ }
249
+ return CombinedLoader(combined_loaders, mode="max_size_cycle")
250
+
251
+ def val_dataloader(self):
252
+ combined_loaders = {
253
+ "val_samples": CombinedLoader(self.val_dataloaders, mode="min_size"),
254
+ "metric_samples": CombinedLoader(
255
+ self.metric_samples_dataloaders, mode="min_size"
256
+ ),
257
+ }
258
+
259
+ return CombinedLoader(combined_loaders, mode="max_size_cycle")
260
+
261
+ def test_dataloader(self):
262
+ combined_loaders = {
263
+ "test_samples": CombinedLoader(self.test_dataloaders, mode="min_size"),
264
+ "metric_samples": CombinedLoader(
265
+ self.metric_samples_dataloaders, mode="min_size"
266
+ ),
267
+ }
268
+
269
+ return CombinedLoader(combined_loaders, mode="max_size_cycle")
270
+
271
+ def get_manifold_proj(self, points):
272
+ """Adapted for 2D cell data - uses local neighborhood averaging instead of plane fitting"""
273
+ return partial(self.local_smoothing_op, tree=self.tree, dataset=self.dataset)
274
+
275
+ @staticmethod
276
+ def local_smoothing_op(x, tree, dataset, k=10, temp=1e-3):
277
+ """
278
+ Apply local smoothing based on k-nearest neighbors in the full dataset
279
+ This replaces the plane projection for 2D manifold regularization
280
+ """
281
+ points_np = x.detach().cpu().numpy()
282
+ _, idx = tree.query(points_np, k=k)
283
+ nearest_pts = dataset[idx] # Shape: (batch_size, k, 2)
284
+
285
+ # Compute weighted average of neighbors
286
+ dists = (x.unsqueeze(1) - nearest_pts).pow(2).sum(-1, keepdim=True)
287
+ weights = torch.exp(-dists / temp)
288
+ weights = weights / weights.sum(dim=1, keepdim=True)
289
+
290
+ # Weighted average of neighbors
291
+ smoothed = (weights * nearest_pts).sum(dim=1)
292
+
293
+ # Blend original point with smoothed version
294
+ alpha = 0.3 # How much smoothing to apply
295
+ return (1 - alpha) * x + alpha * smoothed
296
+
297
+ def get_timepoint_data(self):
298
+ """Return data organized by timepoints for visualization"""
299
+ return {
300
+ 't0': self.coords_t0,
301
+ 't1_1': self.coords_t1_1,
302
+ 't1_2': self.coords_t1_2,
303
+ 't1_3': self.coords_t1_3,
304
+ 'time_labels': self.time_labels
305
+ }
306
+
307
+ def get_datamodule():
308
+ datamodule = ThreeBranchTahoeDataModule(args)
309
+ datamodule.setup(stage="fit")
310
+ return datamodule
dataloaders/trametinib_single.py ADDED
@@ -0,0 +1,279 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import sys
3
+ sys.argv = ['']
4
+ from sklearn.preprocessing import StandardScaler
5
+ import pytorch_lightning as pl
6
+ from torch.utils.data import DataLoader
7
+ from lightning.pytorch.utilities.combined_loader import CombinedLoader
8
+ import pandas as pd
9
+ import numpy as np
10
+ from functools import partial
11
+ from scipy.spatial import cKDTree
12
+ from sklearn.cluster import KMeans
13
+ from torch.utils.data import TensorDataset
14
+
15
+ from train.parsers_tahoe import parse_args
16
+ args = parse_args()
17
+
18
+ class TrametinibSingleBranchDataModule(pl.LightningDataModule):
19
+ def __init__(self, args):
20
+ super().__init__()
21
+ self.save_hyperparameters()
22
+
23
+ self.batch_size = args.batch_size
24
+ self.max_dim = args.dim
25
+ self.whiten = args.whiten
26
+ self.split_ratios = args.split_ratios
27
+ self.num_timesteps = 2
28
+ self.data_path = "./data/Trametinib_5.0uM_pca_and_leidenumap_labels.csv"
29
+ self.args = args
30
+
31
+ self._prepare_data()
32
+
33
+ def _prepare_data(self):
34
+ df = pd.read_csv(self.data_path, comment='#')
35
+ df = df.iloc[:, 1:]
36
+ df = df.replace('', np.nan)
37
+ pc_cols = df.columns[:50]
38
+ for col in pc_cols:
39
+ df[col] = pd.to_numeric(df[col], errors='coerce')
40
+ leiden_dmso_col = 'leiden_DMSO_TF_0.0uM'
41
+ leiden_clonidine_col = 'leiden_Trametinib_5.0uM'
42
+
43
+ dmso_mask = df[leiden_dmso_col].notna() # Has leiden value in DMSO column
44
+ clonidine_mask = df[leiden_clonidine_col].notna() # Has leiden value in Clonidine column
45
+
46
+ dmso_data = df[dmso_mask].copy()
47
+ clonidine_data = df[clonidine_mask].copy()
48
+
49
+ # Updated to include all three clusters: 0, 4, and 6
50
+ top_clonidine_clusters = ['1.0', '3.0', '5.0']
51
+
52
+ x1_1_data = clonidine_data[clonidine_data[leiden_clonidine_col].astype(str) == top_clonidine_clusters[0]]
53
+ x1_2_data = clonidine_data[clonidine_data[leiden_clonidine_col].astype(str) == top_clonidine_clusters[1]]
54
+ x1_3_data = clonidine_data[clonidine_data[leiden_clonidine_col].astype(str) == top_clonidine_clusters[2]]
55
+
56
+ x1_1_coords = x1_1_data[pc_cols].values
57
+ x1_2_coords = x1_2_data[pc_cols].values
58
+ x1_3_coords = x1_3_data[pc_cols].values
59
+
60
+ x1_1_coords = x1_1_coords.astype(float)
61
+ x1_2_coords = x1_2_coords.astype(float)
62
+ x1_3_coords = x1_3_coords.astype(float)
63
+
64
+ # Target size is now the minimum across all three endpoint clusters
65
+ target_size = min(len(x1_1_coords), len(x1_2_coords), len(x1_3_coords))
66
+
67
+ # Helper function to select points closest to centroid
68
+ def select_closest_to_centroid(coords, target_size):
69
+ if len(coords) <= target_size:
70
+ return coords
71
+
72
+ # Calculate centroid
73
+ centroid = np.mean(coords, axis=0)
74
+
75
+ # Calculate distances to centroid
76
+ distances = np.linalg.norm(coords - centroid, axis=1)
77
+
78
+ # Get indices of closest points
79
+ closest_indices = np.argsort(distances)[:target_size]
80
+
81
+ return coords[closest_indices]
82
+
83
+ # Sample all endpoint clusters to target size using centroid-based selection
84
+ x1_1_coords = select_closest_to_centroid(x1_1_coords, target_size)
85
+ x1_2_coords = select_closest_to_centroid(x1_2_coords, target_size)
86
+ x1_3_coords = select_closest_to_centroid(x1_3_coords, target_size)
87
+
88
+ dmso_cluster_counts = dmso_data[leiden_dmso_col].value_counts()
89
+
90
+ # DMSO (unchanged)
91
+ largest_dmso_cluster = dmso_cluster_counts.index[0]
92
+ dmso_cluster_data = dmso_data[dmso_data[leiden_dmso_col] == largest_dmso_cluster]
93
+
94
+ dmso_coords = dmso_cluster_data[pc_cols].values
95
+
96
+ # Random sampling from largest DMSO cluster to match target size
97
+ # For DMSO, we'll also use centroid-based selection for consistency
98
+ if len(dmso_coords) >= target_size:
99
+ x0_coords = select_closest_to_centroid(dmso_coords, target_size)
100
+ else:
101
+ # If largest cluster is smaller than target, use all of it and pad with other DMSO cells
102
+ remaining_needed = target_size - len(dmso_coords)
103
+ other_dmso_data = dmso_data[dmso_data[leiden_dmso_col] != largest_dmso_cluster]
104
+ other_dmso_coords = other_dmso_data[pc_cols].values
105
+
106
+ if len(other_dmso_coords) >= remaining_needed:
107
+ # Select closest to centroid from other DMSO cells
108
+ other_selected = select_closest_to_centroid(other_dmso_coords, remaining_needed)
109
+ x0_coords = np.vstack([dmso_coords, other_selected])
110
+ else:
111
+ # Use all available DMSO cells and reduce target size
112
+ all_dmso_coords = dmso_data[pc_cols].values
113
+ target_size = min(target_size, len(all_dmso_coords))
114
+ x0_coords = select_closest_to_centroid(all_dmso_coords, target_size)
115
+
116
+ # Re-select endpoint clusters with updated target size
117
+ x1_1_coords = select_closest_to_centroid(x1_1_data[pc_cols].values.astype(float), target_size)
118
+ x1_2_coords = select_closest_to_centroid(x1_2_data[pc_cols].values.astype(float), target_size)
119
+ x1_3_coords = select_closest_to_centroid(x1_3_data[pc_cols].values.astype(float), target_size)
120
+
121
+ # No need to resample since we already selected the right number
122
+ # The endpoint clusters are already at target_size from centroid-based selection
123
+
124
+ self.n_samples = target_size
125
+
126
+ # for plotting
127
+
128
+
129
+ x0 = torch.tensor(x0_coords, dtype=torch.float32)
130
+ x1_1 = torch.tensor(x1_1_coords, dtype=torch.float32)
131
+ x1_2 = torch.tensor(x1_2_coords, dtype=torch.float32)
132
+ x1_3 = torch.tensor(x1_3_coords, dtype=torch.float32)
133
+ x1 = torch.cat([x1_1, x1_2, x1_3], dim=0)
134
+
135
+ self.coords_t0 = x0
136
+ self.coords_t1 = x1
137
+
138
+ self.time_labels = np.concatenate([
139
+ np.zeros(len(self.coords_t0)), # t=0
140
+ np.ones(len(self.coords_t1)), # t=1
141
+ ])
142
+
143
+ split_index = int(target_size * self.split_ratios[0])
144
+
145
+ if target_size - split_index < self.batch_size:
146
+ split_index = target_size - self.batch_size
147
+
148
+ train_x0 = x0[:split_index]
149
+ val_x0 = x0[split_index:]
150
+ train_x1 = x1_1[:split_index]
151
+ val_x1 = x1_1[split_index:]
152
+
153
+ self.val_x0 = val_x0
154
+
155
+ train_x0_weights = torch.full((train_x0.shape[0], 1), fill_value=1.0)
156
+ train_x1_weights = torch.full((train_x1.shape[0], 1), fill_value=1.0)
157
+
158
+ val_x0_weights = torch.full((val_x0.shape[0], 1), fill_value=1.0)
159
+ val_x1_weights = torch.full((val_x1.shape[0], 1), fill_value=1.0)
160
+
161
+ # Updated train dataloaders to include x1_3
162
+ self.train_dataloaders = {
163
+ "x0": DataLoader(TensorDataset(train_x0, train_x0_weights), batch_size=self.batch_size, shuffle=True, drop_last=True),
164
+ "x1": DataLoader(TensorDataset(train_x1, train_x1_weights), batch_size=self.batch_size, shuffle=True, drop_last=True),
165
+ }
166
+
167
+ # Updated val dataloaders to include x1_3
168
+ self.val_dataloaders = {
169
+ "x0": DataLoader(TensorDataset(val_x0, val_x0_weights), batch_size=self.batch_size, shuffle=False, drop_last=True),
170
+ "x1": DataLoader(TensorDataset(val_x1, val_x1_weights), batch_size=self.batch_size, shuffle=True, drop_last=True),
171
+ }
172
+
173
+ all_coords = df[pc_cols].dropna().values.astype(float)
174
+ self.dataset = torch.tensor(all_coords, dtype=torch.float32)
175
+ self.tree = cKDTree(all_coords)
176
+
177
+ self.test_dataloaders = {
178
+ "x0": DataLoader(TensorDataset(val_x0, val_x0_weights), batch_size=self.val_x0.shape[0], shuffle=False, drop_last=False),
179
+ "dataset": DataLoader(TensorDataset(self.dataset), batch_size=self.dataset.shape[0], shuffle=False, drop_last=False),
180
+ }
181
+
182
+ # Updated metric samples - now using 4 clusters instead of 3
183
+ #km_all = KMeans(n_clusters=4, random_state=42).fit(self.dataset.numpy())
184
+ km_all = KMeans(n_clusters=2, random_state=0).fit(self.dataset[:, :3].numpy())
185
+
186
+ cluster_labels = km_all.labels_
187
+
188
+ cluster_0_mask = cluster_labels == 0
189
+ cluster_1_mask = cluster_labels == 1
190
+
191
+ samples = self.dataset.cpu().numpy()
192
+
193
+ cluster_0_data = samples[cluster_0_mask]
194
+ cluster_1_data = samples[cluster_1_mask]
195
+
196
+ self.metric_samples_dataloaders = [
197
+ DataLoader(
198
+ torch.tensor(cluster_1_data, dtype=torch.float32),
199
+ batch_size=cluster_1_data.shape[0],
200
+ shuffle=False,
201
+ drop_last=False,
202
+ ),
203
+ DataLoader(
204
+ torch.tensor(cluster_0_data, dtype=torch.float32),
205
+ batch_size=cluster_0_data.shape[0],
206
+ shuffle=False,
207
+ drop_last=False,
208
+ ),
209
+ ]
210
+
211
+ def train_dataloader(self):
212
+ combined_loaders = {
213
+ "train_samples": CombinedLoader(self.train_dataloaders, mode="min_size"),
214
+ "metric_samples": CombinedLoader(
215
+ self.metric_samples_dataloaders, mode="min_size"
216
+ ),
217
+ }
218
+ return CombinedLoader(combined_loaders, mode="max_size_cycle")
219
+
220
+ def val_dataloader(self):
221
+ combined_loaders = {
222
+ "val_samples": CombinedLoader(self.val_dataloaders, mode="min_size"),
223
+ "metric_samples": CombinedLoader(
224
+ self.metric_samples_dataloaders, mode="min_size"
225
+ ),
226
+ }
227
+
228
+ return CombinedLoader(combined_loaders, mode="max_size_cycle")
229
+
230
+
231
+
232
+ def test_dataloader(self):
233
+ combined_loaders = {
234
+ "test_samples": CombinedLoader(self.test_dataloaders, mode="min_size"),
235
+ "metric_samples": CombinedLoader(
236
+ self.metric_samples_dataloaders, mode="min_size"
237
+ ),
238
+ }
239
+
240
+ return CombinedLoader(combined_loaders, mode="max_size_cycle")
241
+
242
+ def get_manifold_proj(self, points):
243
+ """Adapted for 2D cell data - uses local neighborhood averaging instead of plane fitting"""
244
+ return partial(self.local_smoothing_op, tree=self.tree, dataset=self.dataset)
245
+
246
+ @staticmethod
247
+ def local_smoothing_op(x, tree, dataset, k=10, temp=1e-3):
248
+ """
249
+ Apply local smoothing based on k-nearest neighbors in the full dataset
250
+ This replaces the plane projection for 2D manifold regularization
251
+ """
252
+ points_np = x.detach().cpu().numpy()
253
+ _, idx = tree.query(points_np, k=k)
254
+ nearest_pts = dataset[idx] # Shape: (batch_size, k, 2)
255
+
256
+ # Compute weighted average of neighbors
257
+ dists = (x.unsqueeze(1) - nearest_pts).pow(2).sum(-1, keepdim=True)
258
+ weights = torch.exp(-dists / temp)
259
+ weights = weights / weights.sum(dim=1, keepdim=True)
260
+
261
+ # Weighted average of neighbors
262
+ smoothed = (weights * nearest_pts).sum(dim=1)
263
+
264
+ # Blend original point with smoothed version
265
+ alpha = 0.3 # How much smoothing to apply
266
+ return (1 - alpha) * x + alpha * smoothed
267
+
268
+ def get_timepoint_data(self):
269
+ """Return data organized by timepoints for visualization"""
270
+ return {
271
+ 't0': self.coords_t0,
272
+ 't1': self.coords_t1,
273
+ 'time_labels': self.time_labels
274
+ }
275
+
276
+ def get_datamodule():
277
+ datamodule = TrametinibSingleBranchDataModule(args)
278
+ datamodule.setup(stage="fit")
279
+ return datamodule
losses/.DS_Store ADDED
Binary file (6.15 kB). View file
 
losses/energy_loss.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, math, numpy as np
2
+ import torch
3
+ import torch.nn as nn
4
+ from torchdiffeq import odeint as odeint2
5
+ from torchmetrics.functional import mean_squared_error
6
+ import ot
7
+
8
+ class EnergySolver(nn.Module):
9
+ def __init__(self, flow_net, growth_net, state_cost, data_manifold_metric=None, samples=None):
10
+ super(EnergySolver, self).__init__()
11
+ self.flow_net = flow_net
12
+ self.growth_net = growth_net
13
+ self.state_cost = state_cost
14
+
15
+ self.data_manifold_metric = data_manifold_metric
16
+ self.samples = samples
17
+
18
+ def forward(self, t, state):
19
+ xt, wt, mt = state
20
+
21
+ xt.requires_grad_(True)
22
+ wt.requires_grad_(True)
23
+ mt.requires_grad_(True)
24
+
25
+ t.requires_grad_(True)
26
+
27
+ ut = self.flow_net(t, xt)
28
+ gt = self.growth_net(t, xt)
29
+
30
+ time=t.expand(xt.shape[0], 1)
31
+ time.requires_grad_(True)
32
+
33
+ dx_dt = ut
34
+ dw_dt = gt
35
+
36
+ if self.data_manifold_metric is not None:
37
+ vel, _, _ = self.data_manifold_metric.calculate_velocity(
38
+ xt, ut, self.samples, 0
39
+ )
40
+
41
+ dm_dt = torch.mean(vel ** 2) * wt
42
+ else:
43
+ dm_dt = ((ut**2).sum(dim =-1) + self.state_cost(xt)) * wt
44
+
45
+ assert xt.shape == dx_dt.shape, f"dx mismatch: expected {xt.shape}, got {dx_dt.shape}"
46
+ assert wt.shape == dw_dt.shape, f"dw mismatch: expected {wt.shape}, got {dw_dt.shape}"
47
+ assert mt.shape == dm_dt.shape, f"dm mismatch: expected {mt.shape}, got {dm_dt.shape}"
48
+ return dx_dt, dw_dt, dm_dt
49
+
50
+ class ReconsLoss(nn.Module):
51
+ def __init__(self, hinge_value=0.01):
52
+ super(ReconsLoss, self).__init__()
53
+ self.hinge_value = hinge_value
54
+
55
+ def __call__(self, source, target, groups = None, to_ignore = None, top_k = 5):
56
+ if groups is not None:
57
+ # for global loss
58
+ c_dist = torch.stack([
59
+ torch.cdist(source[i], target[i])
60
+
61
+ for i in range(1,len(groups))
62
+ if groups[i] != to_ignore
63
+ ])
64
+ else:
65
+ # for local loss
66
+ c_dist = torch.stack([
67
+ torch.cdist(source, target)
68
+ ])
69
+ values, _ = torch.topk(c_dist, top_k, dim=2, largest=False, sorted=False)
70
+ values -= self.hinge_value
71
+ values[values<0] = 0
72
+ loss = torch.mean(values)
73
+ return loss
networks/.DS_Store ADDED
Binary file (6.15 kB). View file
 
networks/flow_mlp.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ sys.path.append("./BranchSBM")
3
+ import torch
4
+ from networks.mlp_base import SimpleDenseNet
5
+
6
+
7
+ class VelocityNet(SimpleDenseNet):
8
+ def __init__(self, dim: int, *args, **kwargs):
9
+ super().__init__(input_size=dim + 1, target_size=dim, *args, **kwargs)
10
+
11
+ def forward(self, t, x):
12
+
13
+ if t.dim() < 1 or t.shape[0] != x.shape[0]:
14
+ t = t.repeat(x.shape[0])[:, None]
15
+ if t.dim() < 2:
16
+ t = t[:, None]
17
+ x = torch.cat([t, x], dim=-1)
18
+ return self.model(x)
networks/growth_mlp.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ sys.path.append("./BranchSBM")
3
+ import torch
4
+ import torch.nn as nn
5
+ from typing import List, Optional
6
+
7
+ from networks.mlp_base import SimpleDenseNet
8
+
9
+ class GrowthNet(SimpleDenseNet):
10
+ def __init__(
11
+ self,
12
+ dim: int,
13
+ activation: str,
14
+ hidden_dims: List[int] = None,
15
+ batch_norm: bool = False,
16
+ negative: bool = False
17
+ ):
18
+ super().__init__(input_size=dim + 1, target_size=1,
19
+ activation=activation,
20
+ batch_norm=batch_norm,
21
+ hidden_dims=hidden_dims)
22
+
23
+ self.softplus = nn.Softplus()
24
+ self.negative = negative
25
+
26
+ def forward(self, t, x):
27
+
28
+ if t.dim() < 1 or t.shape[0] != x.shape[0]:
29
+ t = t.repeat(x.shape[0])[:, None]
30
+ if t.dim() < 2:
31
+ t = t[:, None]
32
+ x = torch.cat([t, x], dim=-1)
33
+ x = self.model(x)
34
+ x = self.softplus(self.model(x))
35
+ if self.negative:
36
+ x = -x
37
+ return x
networks/interpolant_mlp.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ sys.path.append("./BranchSBM")
3
+ import torch
4
+ import torch.nn as nn
5
+ from typing import List, Optional
6
+
7
+ from networks.mlp_base import SimpleDenseNet
8
+
9
+ class GeoPathMLP(nn.Module):
10
+ def __init__(
11
+ self,
12
+ input_dim: int,
13
+ activation: str,
14
+ batch_norm: bool = True,
15
+ hidden_dims: Optional[List[int]] = None,
16
+ time_geopath: bool = False,
17
+ ):
18
+ super().__init__()
19
+ self.input_dim = input_dim
20
+ self.time_geopath = time_geopath
21
+ self.mainnet = SimpleDenseNet(
22
+ input_size=2 * input_dim + (1 if time_geopath else 0),
23
+ target_size=input_dim,
24
+ activation=activation,
25
+ batch_norm=batch_norm,
26
+ hidden_dims=hidden_dims,
27
+ )
28
+
29
+ def forward(
30
+ self, x0: torch.Tensor, x1: torch.Tensor, t: torch.Tensor
31
+ ) -> torch.Tensor:
32
+ x = torch.cat([x0, x1], dim=1)
33
+ if self.time_geopath:
34
+ x = torch.cat([x, t], dim=1)
35
+ return self.mainnet(x)
networks/mlp_base.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ sys.path.append("./BranchSBM")
3
+ import torch.nn as nn
4
+ import torch
5
+ from typing import List, Optional
6
+
7
+ class swish(nn.Module):
8
+ def forward(self, x):
9
+ return x * torch.sigmoid(x)
10
+
11
+
12
+ ACTIVATION_MAP = {
13
+ "relu": nn.ReLU,
14
+ "sigmoid": nn.Sigmoid,
15
+ "tanh": nn.Tanh,
16
+ "selu": nn.SELU,
17
+ "elu": nn.ELU,
18
+ "lrelu": nn.LeakyReLU,
19
+ "softplus": nn.Softplus,
20
+ "silu": nn.SiLU,
21
+ "swish": swish,
22
+ }
23
+
24
+
25
+ class SimpleDenseNet(nn.Module):
26
+ def __init__(
27
+ self,
28
+ input_size: int,
29
+ target_size: int,
30
+ activation: str,
31
+ batch_norm: bool = False,
32
+ hidden_dims: List[int] = None,
33
+ ):
34
+ super().__init__()
35
+ dims = [input_size, *hidden_dims, target_size]
36
+ layers = []
37
+ for i in range(len(dims) - 2):
38
+ layers.append(nn.Linear(dims[i], dims[i + 1]))
39
+ if batch_norm:
40
+ layers.append(nn.BatchNorm1d(dims[i + 1]))
41
+ layers.append(ACTIVATION_MAP[activation]())
42
+ layers.append(nn.Linear(dims[-2], dims[-1]))
43
+ self.model = nn.Sequential(*layers)
44
+
45
+ def forward(self, x):
46
+ return self.model(x)
networks/utils.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ sys.path.append("./BranchSBM")
3
+ import torch
4
+
5
+ class flow_model_torch_wrapper(torch.nn.Module):
6
+ """Wraps model to torchdyn compatible format."""
7
+
8
+ def __init__(self, model):
9
+ super().__init__()
10
+ self.model = model
11
+
12
+ def forward(self, t, x, *args, **kwargs):
13
+ return self.model(t, x)
state_costs/.DS_Store ADDED
Binary file (6.15 kB). View file
 
state_costs/land.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+
4
+ def weighting_function(x, samples, gamma):
5
+ pairwise_sq_diff = (x[:, None, :] - samples[None, :, :]) ** 2
6
+ pairwise_sq_dist = pairwise_sq_diff.sum(-1)
7
+ weights = torch.exp(-pairwise_sq_dist / (2 * gamma**2))
8
+ return weights
9
+
10
+
11
+ def land_metric_tensor(x, samples, gamma, rho):
12
+ weights = weighting_function(x, samples, gamma) # Shape [B, N]
13
+ differences = samples[None, :, :] - x[:, None, :] # Shape [B, N, D]
14
+ squared_differences = differences**2 # Shape [B, N, D]
15
+
16
+ # Compute the sum of weighted squared differences for each dimension
17
+ M_dd_diag = torch.einsum("bn,bnd->bd", weights, squared_differences) + rho
18
+
19
+ # Invert the metric tensor diagonal for each x_t
20
+ M_dd_inv_diag = 1.0 / M_dd_diag # Shape [B, D] since it's diagonal
21
+ return M_dd_inv_diag
22
+
23
+
24
+ def weighting_function_dt(x, dx_dt, samples, gamma, weights):
25
+ pairwise_sq_diff_dt = (x[:, None, :] - samples[None, :, :]) * dx_dt[:, None, :]
26
+ return -pairwise_sq_diff_dt.sum(-1) * weights / (gamma**2)
state_costs/metric_factory.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ sys.path.append("./BranchSBM")
3
+ import torch
4
+ import pytorch_lightning as pl
5
+ from pytorch_lightning.loggers import WandbLogger
6
+ from torch.utils.data import Dataset, DataLoader
7
+
8
+ from state_costs.land import land_metric_tensor
9
+ from state_costs.rbf import RBFNetwork
10
+
11
+ class DataManifoldMetric:
12
+ def __init__(
13
+ self,
14
+ args,
15
+ skipped_time_points=None,
16
+ datamodule=None,
17
+ ):
18
+ self.skipped_time_points = skipped_time_points
19
+ self.datamodule = datamodule
20
+
21
+ self.gamma = args.gamma_current
22
+ self.rho = args.rho
23
+ self.metric = args.velocity_metric
24
+ self.n_centers = args.n_centers
25
+ self.kappa = args.kappa
26
+ self.metric_epochs = args.metric_epochs
27
+ self.metric_patience = args.metric_patience
28
+ self.lr = args.metric_lr
29
+ self.alpha_metric = args.alpha_metric
30
+ self.image_data = args.data_type == "image"
31
+ self.accelerator = args.accelerator
32
+
33
+ self.called_first_time = True
34
+ self.args = args
35
+
36
+ def calculate_metric(self, x_t, samples, current_timestep):
37
+ if self.metric == "land":
38
+ M_dd_x_t = (
39
+ land_metric_tensor(x_t, samples, self.gamma, self.rho)
40
+ ** self.alpha_metric
41
+ )
42
+
43
+ elif self.metric == "rbf":
44
+ if self.called_first_time:
45
+ self.rbf_networks = []
46
+ for timestep in range(self.datamodule.num_timesteps - 1):
47
+ if timestep in self.skipped_time_points:
48
+ continue
49
+ print("Learning RBF networks, timestep: ", timestep)
50
+ rbf_network = RBFNetwork(
51
+ current_timestep=timestep,
52
+ next_timestep=timestep
53
+ + 1
54
+ + (1 if timestep + 1 in self.skipped_time_points else 0),
55
+ n_centers=self.n_centers,
56
+ kappa=self.kappa,
57
+ lr=self.lr,
58
+ datamodule=self.datamodule,
59
+ args=self.args
60
+ )
61
+ early_stop_callback = pl.callbacks.EarlyStopping(
62
+ monitor="MetricModel/val_loss_learn_metric",
63
+ patience=self.metric_patience,
64
+ mode="min",
65
+ )
66
+ trainer = pl.Trainer(
67
+ max_epochs=self.metric_epochs,
68
+ accelerator=self.accelerator,
69
+ logger=WandbLogger(),
70
+ num_sanity_val_steps=0,
71
+ callbacks=(
72
+ [early_stop_callback] if not self.image_data else None
73
+ ),
74
+ )
75
+ if self.image_data:
76
+ self.dataloader = DataLoader(
77
+ self.datamodule.all_data,
78
+ batch_size=128,
79
+ shuffle=True,
80
+ )
81
+ trainer.fit(rbf_network, self.dataloader)
82
+ else:
83
+ trainer.fit(rbf_network, self.datamodule)
84
+ self.rbf_networks.append(rbf_network)
85
+ self.called_first_time = False
86
+ print("Learning RBF networksss... Done")
87
+ M_dd_x_t = self.rbf_networks[current_timestep].compute_metric(
88
+ x_t,
89
+ epsilon=self.rho,
90
+ alpha=self.alpha_metric,
91
+ image_hx=self.image_data,
92
+ )
93
+ return M_dd_x_t
94
+
95
+ def calculate_velocity(self, x_t, u_t, samples, timestep):
96
+
97
+ if len(u_t.shape) > 2:
98
+ u_t = u_t.reshape(u_t.shape[0], -1)
99
+ x_t = x_t.reshape(x_t.shape[0], -1)
100
+ M_dd_x_t = self.calculate_metric(x_t, samples, timestep).to(u_t.device)
101
+
102
+ velocity = torch.sqrt(((u_t**2) * M_dd_x_t).sum(dim=-1))
103
+ ut_sum = (u_t**2).sum(dim=-1)
104
+ metric_sum = M_dd_x_t.sum(dim=-1)
105
+ return velocity, ut_sum, metric_sum
state_costs/rbf.py ADDED
@@ -0,0 +1,156 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pytorch_lightning as pl
2
+ import torch
3
+ from sklearn.cluster import KMeans
4
+ import numpy as np
5
+
6
+ class RBFNetwork(pl.LightningModule):
7
+ def __init__(
8
+ self,
9
+ current_timestep,
10
+ next_timestep,
11
+ n_centers: int = 100,
12
+ kappa: float = 1.0,
13
+ lr=1e-2,
14
+ datamodule=None,
15
+ image_data=False,
16
+ args=None
17
+ ):
18
+ super().__init__()
19
+ self.K = n_centers
20
+ self.current_timestep = current_timestep
21
+ self.next_timestep = next_timestep
22
+ self.clustering_model = KMeans(n_clusters=self.K)
23
+ self.kappa = kappa
24
+ self.last_val_loss = 1
25
+ self.lr = lr
26
+ self.W = torch.nn.Parameter(torch.rand(self.K, 1))
27
+ self.datamodule = datamodule
28
+ self.image_data = image_data
29
+ self.args = args
30
+
31
+ def on_before_zero_grad(self, *args, **kwargs):
32
+ self.W.data = torch.clamp(self.W.data, min=0.0001)
33
+
34
+ def on_train_start(self):
35
+ with torch.no_grad():
36
+
37
+ batch = next(iter(self.trainer.datamodule.train_dataloader()))
38
+
39
+ metric_samples = batch[0]["metric_samples"][0]
40
+ all_data = torch.cat(metric_samples)
41
+ data_to_fit = all_data
42
+
43
+ print("Fitting Clustering model...")
44
+ self.clustering_model.fit(data_to_fit)
45
+
46
+ clusters = (
47
+ self.calculate_centroids(all_data, self.clustering_model.labels_)
48
+ if self.image_data
49
+ else self.clustering_model.cluster_centers_
50
+ )
51
+
52
+ self.C = torch.tensor(clusters, dtype=torch.float32).to(self.device)
53
+ labels = self.clustering_model.labels_
54
+ sigmas = np.zeros((self.K, 1))
55
+
56
+ for k in range(self.K):
57
+ points = all_data[labels == k, :]
58
+ variance = ((points - clusters[k]) ** 2).mean(axis=0)
59
+ sigmas[k, :] = np.sqrt(
60
+ variance.sum() if self.image_data else variance.mean()
61
+ )
62
+
63
+ self.lamda = torch.tensor(
64
+ 0.5 / (self.kappa * sigmas) ** 2, dtype=torch.float32
65
+ ).to(self.device)
66
+
67
+ def forward(self, x):
68
+ if len(x.shape) > 2:
69
+ x = x.reshape(x.shape[0], -1).to(self.C.device)
70
+
71
+ x = x.to(self.C.device)
72
+ dist2 = torch.cdist(x, self.C) ** 2
73
+ self.phi_x = torch.exp(-0.5 * self.lamda[None, :, :] * dist2[:, :, None])
74
+
75
+ h_x = (self.W.to(x.device) * self.phi_x).sum(dim=1)
76
+
77
+ return h_x
78
+
79
+ def training_step(self, batch, batch_idx):
80
+ if self.args.data_type == "scrna" or self.args.data_type == "tahoe":
81
+ main_batch = batch[0]["train_samples"][0]
82
+ else:
83
+ main_batch = batch["train_samples"][0]
84
+
85
+ x0 = main_batch["x0"][0]
86
+ if self.args.branches == 1:
87
+ x1 = main_batch["x1"][0]
88
+ inputs = torch.cat([x0, x1], dim=0).to(self.device)
89
+ else:
90
+ x1_1 = main_batch["x1_1"][0]
91
+ x1_2 = main_batch["x1_2"][0]
92
+
93
+ inputs = torch.cat([x0, x1_1, x1_2], dim=0).to(self.device)
94
+ print("inputs shape")
95
+ print(inputs.shape)
96
+
97
+ loss = ((1 - self.forward(inputs)) ** 2).mean()
98
+ self.log(
99
+ "MetricModel/train_loss_learn_metric",
100
+ loss,
101
+ on_step=True,
102
+ on_epoch=True,
103
+ prog_bar=True,
104
+ )
105
+ return loss
106
+
107
+ def validation_step(self, batch, batch_idx):
108
+ if self.args.data_type == "scrna" or self.args.data_type == "tahoe":
109
+ main_batch = batch[0]["val_samples"][0]
110
+ else:
111
+ main_batch = batch["val_samples"][0]
112
+
113
+ x0 = main_batch["x0"][0]
114
+ if self.args.branches == 1:
115
+ x1 = main_batch["x1"][0]
116
+ inputs = torch.cat([x0, x1], dim=0).to(self.device)
117
+ else:
118
+ x1_1 = main_batch["x1_1"][0]
119
+ x1_2 = main_batch["x1_2"][0]
120
+
121
+ inputs = torch.cat([x0, x1_1, x1_2], dim=0).to(self.device)
122
+
123
+ h = self.forward(inputs)
124
+
125
+ loss = ((1 - h) ** 2).mean()
126
+ self.log(
127
+ "MetricModel/val_loss_learn_metric",
128
+ loss,
129
+ on_step=True,
130
+ on_epoch=True,
131
+ prog_bar=True,
132
+ )
133
+ self.last_val_loss = loss.detach()
134
+ return loss
135
+
136
+ def calculate_centroids(self, all_data, labels):
137
+ unique_labels = np.unique(labels)
138
+ centroids = np.zeros((len(unique_labels), all_data.shape[1]))
139
+ for i, label in enumerate(unique_labels):
140
+ centroids[i] = all_data[labels == label].mean(axis=0)
141
+ return centroids
142
+
143
+ def configure_optimizers(self):
144
+ optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
145
+ return optimizer
146
+
147
+ def compute_metric(self, x, alpha=1, epsilon=1e-2, image_hx=False):
148
+ if epsilon < 0:
149
+ epsilon = (1 - self.last_val_loss.item()) / abs(epsilon)
150
+ h_x = self.forward(x)
151
+ if image_hx:
152
+ h_x = 1 - torch.abs(1 - h_x)
153
+ M_x = 1 / (h_x**alpha + epsilon)
154
+ else:
155
+ M_x = 1 / (h_x + epsilon) ** alpha
156
+ return M_x
train/.DS_Store ADDED
Binary file (6.15 kB). View file
 
train/main_branches.py ADDED
@@ -0,0 +1,342 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ sys.path.append("./BranchSBM")
3
+ import os
4
+ import sys
5
+ import argparse
6
+ import copy
7
+ from pytorch_lightning import Trainer
8
+ from pytorch_lightning.loggers import WandbLogger
9
+ import wandb
10
+ import hydra
11
+ from omegaconf import DictConfig, OmegaConf
12
+ from torchcfm.optimal_transport import OTPlanSampler
13
+
14
+ from branchsbm.branchsbm import BranchSBM
15
+ from branchsbm.branch_flow_net_train import FlowNetTrainCell, FlowNetTrainLidar
16
+ from branchsbm.branch_interpolant_train import BranchInterpolantTrain
17
+
18
+ from dataloaders.trajectory_data import TemporalDataModule
19
+ from dataloaders.mouse_data import WeightedBranchedCellDataModule
20
+ from dataloaders.three_branch_data import ThreeBranchTahoeDataModule
21
+ from dataloaders.clonidine_v2_data import ClonidineV2DataModule
22
+ from dataloaders.clonidine_single_branch import ClonidineSingleBranchDataModule
23
+ from dataloaders.trametinib_single import TrametinibSingleBranchDataModule
24
+ from dataloaders.lidar_data import WeightedBranchedLidarDataModule
25
+ from dataloaders.lidar_data_single import LidarSingleDataModule
26
+
27
+ from networks.flow_networks.mlp import VelocityNet
28
+ from networks.growth_networks.mlp import GrowthNet
29
+ from networks.geopath_networks.mlp import GeoPathMLP
30
+ from networks.unet_base import UNetModelWrapper as UNetModel
31
+ from networks.geopath_networks.unet import GeoPathUNet
32
+ from utils import set_seed
33
+
34
+ from train.parsers import parse_args
35
+ from flow_matchers.ema import EMA
36
+ from train.train_utils import (
37
+ load_config,
38
+ merge_config,
39
+ generate_group_string,
40
+ dataset_name2datapath,
41
+ create_callbacks,
42
+ )
43
+ from geo_metrics.metric_factory import DataManifoldMetric
44
+ import torch.nn as nn
45
+ from flow_matchers.branch_growth_net_train import GrowthNetTrain, GrowthNetTrainCell, GrowthNetTrainLidar
46
+
47
+ def main(args: argparse.Namespace, seed: int, t_exclude: int) -> None:
48
+ set_seed(seed)
49
+ branches = args.branches
50
+
51
+ skipped_time_points = [t_exclude] if t_exclude else []
52
+
53
+ ### DATAMODULES ###
54
+ if args.data_name == "lidar":
55
+ datamodule = WeightedBranchedLidarDataModule(args=args)
56
+ elif args.data_name == "lidarsingle":
57
+ datamodule = LidarSingleDataModule(args=args)
58
+ elif args.data_name == "mouse":
59
+ datamodule = WeightedBranchedCellDataModule(args=args)
60
+ elif args.data_name in ["clonidine50D", "clonidine100D", "clonidine150D"]:
61
+ datamodule = ClonidineV2DataModule(args=args)
62
+ elif args.data_name == "clonidine50Dsingle":
63
+ datamodule = ClonidineSingleBranchDataModule(args=args)
64
+ elif args.data_name == "trametinib":
65
+ datamodule = ThreeBranchTahoeDataModule(args=args)
66
+ elif args.data_name == "trametinibsingle":
67
+ datamodule = TrametinibSingleBranchDataModule(args=args)
68
+
69
+ flow_nets = nn.ModuleList()
70
+ geopath_nets = nn.ModuleList()
71
+ growth_nets = nn.ModuleList()
72
+
73
+ ##### initialize branched flow and growth networks #####
74
+ for i in range(branches):
75
+ flow_net = VelocityNet(
76
+ dim=args.dim,
77
+ hidden_dims=args.hidden_dims_flow,
78
+ activation=args.activation_flow,
79
+ batch_norm=False,
80
+ )
81
+ geopath_net = GeoPathMLP(
82
+ input_dim=args.dim,
83
+ hidden_dims=args.hidden_dims_geopath,
84
+ time_geopath=args.time_geopath,
85
+ activation=args.activation_geopath,
86
+ batch_norm=False,
87
+ )
88
+
89
+ if i == 0:
90
+ growth_net = GrowthNet(
91
+ dim=args.dim,
92
+ hidden_dims=args.hidden_dims_growth,
93
+ activation=args.activation_growth,
94
+ batch_norm=False,
95
+ negative=True
96
+ )
97
+ else:
98
+ growth_net = GrowthNet(
99
+ dim=args.dim,
100
+ hidden_dims=args.hidden_dims_growth,
101
+ activation=args.activation_growth,
102
+ batch_norm=False,
103
+ negative=False
104
+ )
105
+
106
+ if args.ema_decay is not None:
107
+ flow_net = EMA(model=flow_net, decay=args.ema_decay)
108
+ geopath_net = EMA(model=geopath_net, decay=args.ema_decay)
109
+ growth_net = EMA(model=growth_net, decay=args.ema_decay)
110
+
111
+ flow_nets.append(flow_net)
112
+ geopath_nets.append(geopath_net)
113
+ growth_nets.append(growth_net)
114
+
115
+
116
+ ot_sampler = (
117
+ OTPlanSampler(method=args.optimal_transport_method)
118
+ if args.optimal_transport_method != "None"
119
+ else None
120
+ )
121
+
122
+ wandb.init(
123
+ project=f"branchsbm-{args.data_name}-{branches}-branches",
124
+ group=args.group_name,
125
+ config=vars(args),
126
+ dir=args.working_dir,
127
+ )
128
+
129
+ flow_matcher_base = BranchSBM(
130
+ geopath_nets=geopath_nets,
131
+ sigma=args.sigma,
132
+ alpha=int(args.branchsbm),
133
+ )
134
+
135
+ ##### STAGE 1: Training of Geodesic Interpolants Beginning #####
136
+
137
+ geopath_callbacks = create_callbacks(
138
+ args, phase="geopath", data_type=args.data_type, run_id=wandb.run.id
139
+ )
140
+
141
+ # define state cost
142
+ data_manifold_metric = DataManifoldMetric(
143
+ args=args,
144
+ skipped_time_points=skipped_time_points,
145
+ datamodule=datamodule,
146
+ )
147
+ geopath_model = BranchInterpolantTrain(
148
+ flow_matcher=flow_matcher_base,
149
+ skipped_time_points=skipped_time_points,
150
+ ot_sampler=ot_sampler,
151
+ args=args,
152
+ data_manifold_metric=data_manifold_metric
153
+ )
154
+
155
+ wandb_logger = WandbLogger()
156
+
157
+ trainer = Trainer(
158
+ max_epochs=args.epochs,
159
+ callbacks=geopath_callbacks,
160
+ accelerator=args.accelerator,
161
+ logger=wandb_logger,
162
+ num_sanity_val_steps=0,
163
+ default_root_dir=args.working_dir,
164
+ gradient_clip_val=(1.0 if args.data_type == "image" else None),
165
+ )
166
+
167
+ if args.load_geopath_model_ckpt:
168
+ best_model_path = args.load_geopath_model_ckpt
169
+ else:
170
+ trainer.fit(
171
+ geopath_model,
172
+ datamodule=datamodule,
173
+ )
174
+
175
+ best_model_path = geopath_callbacks[0].best_model_path
176
+
177
+ geopath_model = BranchInterpolantTrain.load_from_checkpoint(best_model_path)
178
+
179
+ flow_matcher_base.geopath_nets = geopath_model.geopath_nets
180
+
181
+ ##### STAGE 1: Training of Geodesic Interpolants End #####
182
+
183
+
184
+ ##### STAGE 2: Flow Matching Beginning #####
185
+ flow_callbacks = create_callbacks(
186
+ args,
187
+ phase="flow",
188
+ data_type=args.data_type,
189
+ run_id=wandb.run.id,
190
+ datamodule=datamodule,
191
+ )
192
+ if args.data_type == "lidar":
193
+ FlowNetTrain = FlowNetTrainLidar
194
+ else:
195
+ FlowNetTrain = FlowNetTrainCell
196
+
197
+ flow_train = FlowNetTrain(
198
+ flow_matcher=flow_matcher_base,
199
+ flow_nets=flow_nets,
200
+ ot_sampler=ot_sampler,
201
+ skipped_time_points=skipped_time_points,
202
+ args=args,
203
+ )
204
+
205
+ wandb_logger = WandbLogger()
206
+
207
+ trainer = Trainer(
208
+ max_epochs=args.epochs,
209
+ callbacks=flow_callbacks,
210
+ check_val_every_n_epoch=args.check_val_every_n_epoch,
211
+ accelerator=args.accelerator,
212
+ logger=wandb_logger,
213
+ default_root_dir=args.working_dir,
214
+ gradient_clip_val=(1.0 if args.data_type == "image" else None),
215
+ num_sanity_val_steps=(0 if args.data_type == "image" else None),
216
+ )
217
+
218
+ trainer.fit(
219
+ flow_train, datamodule=datamodule, ckpt_path=args.resume_flow_model_ckpt
220
+ )
221
+ if args.data_type == "lidar":
222
+ trainer.test(flow_train, datamodule=datamodule)
223
+ ##### STAGE 2: Flow Matching End #####
224
+
225
+
226
+ ##### STAGE 3: Training Growth Networks Beginning ####
227
+ flow_nets = flow_train.flow_nets
228
+
229
+ growth_callbacks = create_callbacks(
230
+ args,
231
+ phase="growth",
232
+ data_type=args.data_type,
233
+ run_id=wandb.run.id,
234
+ datamodule=datamodule,
235
+ )
236
+
237
+ if args.data_type == "lidar":
238
+ GrowthNetTrain = GrowthNetTrainLidar
239
+ else:
240
+ GrowthNetTrain = GrowthNetTrainCell
241
+
242
+ growth_train = GrowthNetTrain(
243
+ flow_nets = flow_nets,
244
+ growth_nets = growth_nets,
245
+ ot_sampler=ot_sampler,
246
+ skipped_time_points=skipped_time_points,
247
+ args=args,
248
+ data_manifold_metric=data_manifold_metric,
249
+ joint = False
250
+ )
251
+
252
+ wandb_logger = WandbLogger()
253
+
254
+ trainer = Trainer(
255
+ max_epochs=args.epochs,
256
+ callbacks=growth_callbacks,
257
+ check_val_every_n_epoch=args.check_val_every_n_epoch,
258
+ accelerator=args.accelerator,
259
+ logger=wandb_logger,
260
+ default_root_dir=args.working_dir,
261
+ gradient_clip_val=(1.0 if args.data_type == "image" else None),
262
+ num_sanity_val_steps=(0 if args.data_type == "image" else None),
263
+ )
264
+
265
+ trainer.fit(
266
+ growth_train, datamodule=datamodule, ckpt_path=None
267
+ )
268
+ trainer.test(growth_train, datamodule=datamodule)
269
+
270
+ ##### STAGE 3: Training Growth Networks End ####
271
+
272
+ ##### STAGE 4: Joint Training Beginning ####
273
+
274
+ growth_nets = growth_train.growth_nets
275
+
276
+ joint_callbacks = create_callbacks(
277
+ args,
278
+ phase="joint",
279
+ data_type=args.data_type,
280
+ run_id=wandb.run.id,
281
+ datamodule=datamodule,
282
+ )
283
+
284
+ if args.data_type == "lidar":
285
+ GrowthNetTrain = GrowthNetTrainLidar
286
+ else:
287
+ GrowthNetTrain = GrowthNetTrainCell
288
+
289
+ joint_train = GrowthNetTrain(
290
+ flow_nets = flow_nets,
291
+ growth_nets = growth_nets,
292
+ ot_sampler=ot_sampler,
293
+ skipped_time_points=skipped_time_points,
294
+ args=args,
295
+ data_manifold_metric=data_manifold_metric,
296
+ joint = True
297
+ )
298
+
299
+ wandb_logger = WandbLogger()
300
+
301
+ trainer = Trainer(
302
+ max_epochs=args.epochs,
303
+ callbacks=joint_callbacks,
304
+ check_val_every_n_epoch=args.check_val_every_n_epoch,
305
+ accelerator=args.accelerator,
306
+ logger=wandb_logger,
307
+ default_root_dir=args.working_dir,
308
+ gradient_clip_val=(1.0 if args.data_type == "image" else None),
309
+ num_sanity_val_steps=(0 if args.data_type == "image" else None),
310
+ )
311
+
312
+ trainer.fit(
313
+ joint_train, datamodule=datamodule, ckpt_path=None
314
+ )
315
+ trainer.test(joint_train, datamodule=datamodule)
316
+
317
+ ##### STAGE 4: Joint Training End ####
318
+
319
+ wandb.finish()
320
+
321
+ if __name__ == "__main__":
322
+ args = parse_args()
323
+ updated_args = copy.deepcopy(args)
324
+ if args.config_path:
325
+ config = load_config(args.config_path)
326
+ updated_args = merge_config(updated_args, config)
327
+
328
+ updated_args.group_name = generate_group_string()
329
+ updated_args.data_path = dataset_name2datapath(
330
+ updated_args.data_name, updated_args.working_dir
331
+ )
332
+ for seed in updated_args.seeds:
333
+ if updated_args.t_exclude:
334
+ for i, t_exclude in enumerate(updated_args.t_exclude):
335
+ updated_args.t_exclude_current = t_exclude
336
+ updated_args.seed_current = seed
337
+ updated_args.gamma_current = updated_args.gammas[i]
338
+ main(updated_args, seed=seed, t_exclude=t_exclude)
339
+ else:
340
+ updated_args.seed_current = seed
341
+ updated_args.gamma_current = updated_args.gammas[0]
342
+ main(updated_args, seed=seed, t_exclude=None)
train/parsers.py ADDED
@@ -0,0 +1,419 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+
3
+ def parse_args():
4
+ parser = argparse.ArgumentParser(description="Train BranchSBM")
5
+
6
+ parser.add_argument(
7
+ "--config_path", type=str,
8
+ default='./configs/experiment/lidar.yaml',
9
+ help="Path to config file"
10
+ )
11
+
12
+ ####### ITERATES IN THE CODE #######
13
+ parser.add_argument(
14
+ "--seeds",
15
+ nargs="+",
16
+ type=int,
17
+ default=[42, 43, 44, 45, 46],
18
+ help="Random seeds to iterate over",
19
+ )
20
+ parser.add_argument(
21
+ "--t_exclude",
22
+ nargs="+",
23
+ type=int,
24
+ default=[1, 2],
25
+ help="Time points to exclude (iterating over)",
26
+ )
27
+ ####################################
28
+
29
+ parser.add_argument(
30
+ "--working_dir",
31
+ type=str,
32
+ default="./",
33
+ help="Working directory",
34
+ )
35
+ parser.add_argument(
36
+ "--resume_flow_model_ckpt",
37
+ type=str,
38
+ default=None,
39
+ help="Path to the flow model to resume training",
40
+ )
41
+ parser.add_argument(
42
+ "--resume_growth_model_ckpt",
43
+ type=str,
44
+ default=None,
45
+ help="Path to the flow model to resume training",
46
+ )
47
+ parser.add_argument(
48
+ "--load_geopath_model_ckpt",
49
+ type=str,
50
+ default=None,
51
+ help="Path to the geopath model to resume training",
52
+ )
53
+ parser.add_argument(
54
+ "--branches",
55
+ type=int,
56
+ default=2,
57
+ help="Number of branches",
58
+ )
59
+ parser.add_argument(
60
+ "--metric_clusters",
61
+ type=int,
62
+ default=3,
63
+ help="Number of metric clusters",
64
+ )
65
+
66
+ ######### DATASETS #################
67
+ parser = datasets_parser(parser)
68
+ ####################################
69
+
70
+ ######### IMAGE DATASETS ###########
71
+ parser = image_datasets_parser(parser)
72
+ ####################################
73
+
74
+ ######### METRICS ##################
75
+ parser = metric_parser(parser)
76
+ ####################################
77
+
78
+ ######### General Training #########
79
+ parser = general_training_parser(parser)
80
+ ####################################
81
+
82
+ ######### Training GeoPath Network ####
83
+ parser = geopath_network_parser(parser)
84
+ ####################################
85
+
86
+ ######### Training Flow Network ####
87
+ parser = flow_network_parser(parser)
88
+ ####################################
89
+
90
+ parser = growth_network_parser(parser)
91
+
92
+ return parser.parse_args()
93
+
94
+
95
+ def datasets_parser(parser):
96
+ parser.add_argument("--dim", type=int, default=3, help="Dimension of data")
97
+
98
+ parser.add_argument(
99
+ "--data_type",
100
+ type=str,
101
+ default="lidar",
102
+ help="Type of data, now wither scrna or one of toys",
103
+ )
104
+ parser.add_argument(
105
+ "--data_path",
106
+ type=str,
107
+ default="./data/rainier2-thin.las",
108
+ help="lidar data path",
109
+ )
110
+ parser.add_argument(
111
+ "--data_name",
112
+ type=str,
113
+ default="lidar",
114
+ help="Path to the dataset",
115
+ )
116
+ parser.add_argument(
117
+ "--whiten",
118
+ action=argparse.BooleanOptionalAction,
119
+ default=True,
120
+ help="Whiten the data",
121
+ )
122
+ return parser
123
+
124
+
125
+ def image_datasets_parser(parser):
126
+ parser.add_argument(
127
+ "--image_size",
128
+ type=int,
129
+ default=128,
130
+ help="Size of the image",
131
+ )
132
+ parser.add_argument(
133
+ "--x0_label",
134
+ type=str,
135
+ default="dog",
136
+ help="Label for x0",
137
+ )
138
+ parser.add_argument(
139
+ "--x1_label",
140
+ type=str,
141
+ default="cat",
142
+ help="Label for x1",
143
+ )
144
+ return parser
145
+
146
+
147
+ def metric_parser(parser):
148
+ parser.add_argument(
149
+ "--branchsbm",
150
+ action=argparse.BooleanOptionalAction,
151
+ default=True,
152
+ help="If branched SBM",
153
+ )
154
+ parser.add_argument(
155
+ "--n_centers",
156
+ type=int,
157
+ default=100,
158
+ help="Number of centers for RBF network",
159
+ )
160
+ parser.add_argument(
161
+ "--kappa",
162
+ type=float,
163
+ default=1.0,
164
+ help="Kappa parameter for RBF network",
165
+ )
166
+ parser.add_argument(
167
+ "--rho",
168
+ type=float,
169
+ default=0.001,
170
+ help="Rho parameter in Riemanian Velocity Calculation",
171
+ )
172
+ parser.add_argument(
173
+ "--velocity_metric",
174
+ type=str,
175
+ default="rbf",
176
+ help="Metric for velocity calculation",
177
+ )
178
+ parser.add_argument(
179
+ "--gammas",
180
+ nargs="+",
181
+ type=float,
182
+ default=[0.2, 0.2],
183
+ help="Gamma parameter in Riemanian Velocity Calculation",
184
+ )
185
+ parser.add_argument(
186
+ "--metric_epochs",
187
+ type=int,
188
+ default=50,
189
+ help="Number of epochs for metric learning",
190
+ )
191
+ parser.add_argument(
192
+ "--metric_patience",
193
+ type=int,
194
+ default=5,
195
+ help="Patience for metric learning",
196
+ )
197
+ parser.add_argument(
198
+ "--metric_lr",
199
+ type=float,
200
+ default=1e-2,
201
+ help="Learning rate for metric learning",
202
+ )
203
+ parser.add_argument(
204
+ "--alpha_metric",
205
+ type=float,
206
+ default=1.0,
207
+ help="Alpha parameter for metric learning",
208
+ )
209
+
210
+ return parser
211
+
212
+
213
+ def general_training_parser(parser):
214
+ parser.add_argument(
215
+ "--batch_size", type=int, default=128, help="Batch size for training"
216
+ )
217
+ parser.add_argument(
218
+ "--optimal_transport_method",
219
+ type=str,
220
+ default="exact",
221
+ help="Use optimal transport in CFM training",
222
+ )
223
+ parser.add_argument(
224
+ "--ema_decay",
225
+ type=float,
226
+ default=None,
227
+ help="Decay for EMA",
228
+ )
229
+ parser.add_argument(
230
+ "--split_ratios",
231
+ nargs=2,
232
+ type=float,
233
+ default=[0.9, 0.1],
234
+ help="Split ratios for training/validation data in CFM training",
235
+ )
236
+ parser.add_argument("--epochs", type=int, default=100, help="Number of epochs")
237
+ parser.add_argument(
238
+ "--accelerator", type=str, default="cpu", help="Training accelerator"
239
+ )
240
+ parser.add_argument(
241
+ "--sim_num_steps",
242
+ type=int,
243
+ default=1000,
244
+ help="Number of steps in simulation",
245
+ )
246
+ return parser
247
+
248
+
249
+ def geopath_network_parser(parser):
250
+ parser.add_argument(
251
+ "--manifold",
252
+ action=argparse.BooleanOptionalAction,
253
+ default=True,
254
+ help="If use data manifold metric",
255
+ )
256
+ parser.add_argument(
257
+ "--patience_geopath",
258
+ type=int,
259
+ default=50,
260
+ help="Patience for training geopath model",
261
+ )
262
+ parser.add_argument(
263
+ "--hidden_dims_geopath",
264
+ nargs="+",
265
+ type=int,
266
+ default=[64, 64, 64],
267
+ help="Dimensions of hidden layers for GeoPath model training",
268
+ )
269
+ parser.add_argument(
270
+ "--time_geopath",
271
+ action=argparse.BooleanOptionalAction,
272
+ default=False,
273
+ help="Use time in GeoPath model",
274
+ )
275
+ parser.add_argument(
276
+ "--activation_geopath",
277
+ type=str,
278
+ default="selu",
279
+ help="Activation function for GeoPath",
280
+ )
281
+ parser.add_argument(
282
+ "--geopath_optimizer",
283
+ type=str,
284
+ default="adam",
285
+ help="Optimizer for GeoPath training",
286
+ )
287
+ parser.add_argument(
288
+ "--geopath_lr",
289
+ type=float,
290
+ default=1e-4,
291
+ help="Learning rate for GeoPath training",
292
+ )
293
+ parser.add_argument(
294
+ "--geopath_weight_decay",
295
+ type=float,
296
+ default=1e-5,
297
+ help="Weight decay for GeoPath training",
298
+ )
299
+ return parser
300
+
301
+
302
+ def flow_network_parser(parser):
303
+ parser.add_argument(
304
+ "--sigma", type=float, default=0.1, help="Sigma parameter for CFM (variance)"
305
+ )
306
+ parser.add_argument(
307
+ "--patience",
308
+ type=int,
309
+ default=5,
310
+ help="Patience for early stopping in CFM training",
311
+ )
312
+ parser.add_argument(
313
+ "--hidden_dims_flow",
314
+ nargs="+",
315
+ type=int,
316
+ default=[64, 64, 64],
317
+ help="Dimensions of hidden layers for CFM training",
318
+ )
319
+ parser.add_argument(
320
+ "--check_val_every_n_epoch",
321
+ type=int,
322
+ default=10,
323
+ help="Check validation every N epochs during CFM training",
324
+ )
325
+ parser.add_argument(
326
+ "--activation_flow",
327
+ type=str,
328
+ default="selu",
329
+ help="Activation function for CFM",
330
+ )
331
+ parser.add_argument(
332
+ "--flow_optimizer",
333
+ type=str,
334
+ default="adamw",
335
+ help="Optimizer for GeoPath training",
336
+ )
337
+ parser.add_argument(
338
+ "--flow_lr",
339
+ type=float,
340
+ default=1e-3,
341
+ help="Learning rate for GeoPath training",
342
+ )
343
+ parser.add_argument(
344
+ "--flow_weight_decay",
345
+ type=float,
346
+ default=1e-5,
347
+ help="Weight decay for GeoPath training",
348
+ )
349
+ return parser
350
+
351
+ def growth_network_parser(parser):
352
+ parser.add_argument(
353
+ "--patience_growth",
354
+ type=int,
355
+ default=5,
356
+ help="Patience for early stopping in CFM training",
357
+ )
358
+ parser.add_argument(
359
+ "--time_growth",
360
+ action=argparse.BooleanOptionalAction,
361
+ default=False,
362
+ help="Use time in GeoPath model",
363
+ )
364
+ parser.add_argument(
365
+ "--hidden_dims_growth",
366
+ nargs="+",
367
+ type=int,
368
+ default=[64, 64, 64],
369
+ help="Dimensions of hidden layers for growth net training",
370
+ )
371
+ parser.add_argument(
372
+ "--activation_growth",
373
+ type=str,
374
+ default="tanh",
375
+ help="Activation function for CFM",
376
+ )
377
+ parser.add_argument(
378
+ "--growth_optimizer",
379
+ type=str,
380
+ default="adamw",
381
+ help="Optimizer for GeoPath training",
382
+ )
383
+ parser.add_argument(
384
+ "--growth_lr",
385
+ type=float,
386
+ default=1e-3,
387
+ help="Learning rate for GeoPath training",
388
+ )
389
+ parser.add_argument(
390
+ "--growth_weight_decay",
391
+ type=float,
392
+ default=1e-5,
393
+ help="Weight decay for GeoPath training",
394
+ )
395
+ parser.add_argument(
396
+ "--lambda_energy",
397
+ type=float,
398
+ default=1.0,
399
+ help="Weight for energy loss",
400
+ )
401
+ parser.add_argument(
402
+ "--lambda_mass",
403
+ type=float,
404
+ default=100.0,
405
+ help="Weight for mass loss",
406
+ )
407
+ parser.add_argument(
408
+ "--lambda_match",
409
+ type=float,
410
+ default=1000.0,
411
+ help="Weight for matching loss",
412
+ )
413
+ parser.add_argument(
414
+ "--lambda_recons",
415
+ type=float,
416
+ default=1.0,
417
+ help="Weight for reconstruction loss",
418
+ )
419
+ return parser
train/train_utils.py ADDED
@@ -0,0 +1,154 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ sys.path.append("./BranchSBM")
3
+ import yaml
4
+ import string
5
+ import secrets
6
+ import os
7
+ import torch
8
+ import wandb
9
+ from pytorch_lightning.callbacks import Callback, EarlyStopping, ModelCheckpoint
10
+ from torchdyn.core import NeuralODE
11
+ from utils import plot_images_trajectory
12
+ from networks.utils import flow_model_torch_wrapper
13
+
14
+
15
+ def load_config(path):
16
+ with open(path, "r") as file:
17
+ config = yaml.safe_load(file)
18
+ return config
19
+
20
+
21
+ def merge_config(args, config_updates):
22
+ for key, value in config_updates.items():
23
+ if not hasattr(args, key):
24
+ raise ValueError(
25
+ f"Unknown configuration parameter '{key}' found in the config file."
26
+ )
27
+ setattr(args, key, value)
28
+ return args
29
+
30
+
31
+ def generate_group_string(length=16):
32
+ alphabet = string.ascii_letters + string.digits
33
+ return "".join(secrets.choice(alphabet) for _ in range(length))
34
+
35
+
36
+ def dataset_name2datapath(dataset_name, working_dir):
37
+ if dataset_name in ["lidar", "lidarsingle"]:
38
+ return os.path.join(working_dir, "/raid/st512/branchsbm/data", "rainier2-thin.las")
39
+ elif dataset_name == "mouse":
40
+ return os.path.join(working_dir, "/raid/st512/branchsbm/data", "mouse_hematopoiesis.csv")
41
+ elif dataset_name in ["clonidine50D", "clonidine100D", "clonidine150D", "clonidine50Dsingle", "clonidine100Dsingle", "clonidine150Dsingle"]:
42
+ return os.path.join(working_dir, "/raid/st512/branchsbm/data", "pca_and_leiden_labels.csv")
43
+ elif dataset_name in ["trametinib", "trametinibsingle"]:
44
+ return os.path.join(working_dir, "/raid/st512/branchsbm/data", "Trametinib_5.0uM_pca_and_leidenumap_labels.csv")
45
+ else:
46
+ raise ValueError("Dataset not recognized")
47
+
48
+
49
+ def create_callbacks(args, phase, data_type, run_id, datamodule=None):
50
+
51
+ dirpath = os.path.join(
52
+ args.working_dir,
53
+ "checkpoints",
54
+ data_type,
55
+ str(run_id),
56
+ f"{phase}_model",
57
+ )
58
+
59
+ if phase == "geopath":
60
+ early_stop_callback = EarlyStopping(
61
+ monitor="BranchPathNet/val_loss_geopath",
62
+ patience=args.patience_geopath,
63
+ mode="min",
64
+ )
65
+ checkpoint_callback = ModelCheckpoint(
66
+ dirpath=dirpath,
67
+ monitor="BranchPathNet/val_loss_geopath",
68
+ mode="min",
69
+ save_top_k=1,
70
+ )
71
+ callbacks = [checkpoint_callback, early_stop_callback]
72
+ elif phase == "flow":
73
+ early_stop_callback = EarlyStopping(
74
+ monitor="FlowNet/val_loss_cfm",
75
+ patience=args.patience,
76
+ mode="min",
77
+ )
78
+ checkpoint_callback = ModelCheckpoint(
79
+ dirpath=dirpath,
80
+ mode="min",
81
+ save_top_k=1,
82
+ )
83
+ callbacks = [checkpoint_callback, early_stop_callback]
84
+ elif phase == "growth":
85
+ early_stop_callback = EarlyStopping(
86
+ monitor="GrowthNet/val_loss",
87
+ patience=args.patience,
88
+ mode="min",
89
+ )
90
+ checkpoint_callback = ModelCheckpoint(
91
+ dirpath=dirpath,
92
+ mode="min",
93
+ save_top_k=1,
94
+ )
95
+ callbacks = [checkpoint_callback, early_stop_callback]
96
+ elif phase == "joint":
97
+ early_stop_callback = EarlyStopping(
98
+ monitor="JointTrain/val_loss",
99
+ patience=args.patience,
100
+ mode="min",
101
+ )
102
+ checkpoint_callback = ModelCheckpoint(
103
+ dirpath=dirpath,
104
+ mode="min",
105
+ save_top_k=1,
106
+ )
107
+ callbacks = [checkpoint_callback, early_stop_callback]
108
+ else:
109
+ raise ValueError("Unknown phase")
110
+ return callbacks
111
+
112
+
113
+ class PlottingCallback(Callback):
114
+ def __init__(self, plot_interval, datamodule):
115
+ self.plot_interval = plot_interval
116
+ self.datamodule = datamodule
117
+
118
+ def on_train_epoch_end(self, trainer, pl_module):
119
+ epoch = trainer.current_epoch
120
+ pl_module.flow_net.train(mode=False)
121
+ if epoch % self.plot_interval == 0 and epoch != 0:
122
+ node = NeuralODE(
123
+ flow_model_torch_wrapper(pl_module.flow_net).to(self.datamodule.device),
124
+ solver="tsit5",
125
+ sensitivity="adjoint",
126
+ atol=1e-5,
127
+ rtol=1e-5,
128
+ )
129
+
130
+ for mode in ["train", "val"]:
131
+ x0 = getattr(self.datamodule, f"{mode}_x0")
132
+ x0 = x0[0:15]
133
+ fig = self.trajectory_and_plot(x0, node, self.datamodule)
134
+ wandb.log({f"Trajectories {mode.capitalize()}": wandb.Image(fig)})
135
+ pl_module.flow_net.train(mode=True)
136
+
137
+ def trajectory_and_plot(self, x0, node, datamodule):
138
+ selected_images = x0[0:15]
139
+ with torch.no_grad():
140
+ traj = node.trajectory(
141
+ selected_images.to(datamodule.device),
142
+ t_span=torch.linspace(0, 1, 100).to(datamodule.device),
143
+ )
144
+
145
+ traj = traj.transpose(0, 1)
146
+ traj = traj.reshape(*traj.shape[0:2], *datamodule.dim)
147
+
148
+ fig = plot_images_trajectory(
149
+ traj.to(datamodule.device),
150
+ datamodule.vae.to(datamodule.device),
151
+ datamodule.process,
152
+ num_steps=5,
153
+ )
154
+ return fig
utils.py ADDED
@@ -0,0 +1,198 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ import random
4
+ import matplotlib
5
+ import matplotlib.pyplot as plt
6
+ import math
7
+ import umap
8
+ import scanpy as sc
9
+ from sklearn.decomposition import PCA
10
+
11
+ import ot as pot
12
+ from tqdm import tqdm
13
+ from functools import partial
14
+ from typing import Optional
15
+
16
+ from matplotlib.colors import LinearSegmentedColormap
17
+
18
+
19
+ def set_seed(seed):
20
+ """
21
+ Sets the seed for reproducibility in PyTorch, Numpy, and Python's Random.
22
+
23
+ Parameters:
24
+ seed (int): The seed for the random number generators.
25
+ """
26
+ random.seed(seed) # Python random module
27
+ np.random.seed(seed) # Numpy
28
+ torch.manual_seed(seed) # CPU and GPU (deterministic)
29
+ if torch.cuda.is_available():
30
+ torch.cuda.manual_seed(seed) # CUDA
31
+ torch.cuda.manual_seed_all(seed) # all GPU devices
32
+ torch.backends.cudnn.deterministic = True # CuDNN behavior
33
+ torch.backends.cudnn.benchmark = False
34
+
35
+
36
+ def wasserstein_distance(
37
+ x0: torch.Tensor,
38
+ x1: torch.Tensor,
39
+ method: Optional[str] = None,
40
+ reg: float = 0.05,
41
+ power: int = 1,
42
+ **kwargs,
43
+ ) -> float:
44
+ assert power == 1 or power == 2
45
+ if method == "exact" or method is None:
46
+ ot_fn = pot.emd2
47
+ elif method == "sinkhorn":
48
+ ot_fn = partial(pot.sinkhorn2, reg=reg)
49
+ else:
50
+ raise ValueError(f"Unknown method: {method}")
51
+
52
+ a, b = pot.unif(x0.shape[0]), pot.unif(x1.shape[0])
53
+ if x0.dim() > 2:
54
+ x0 = x0.reshape(x0.shape[0], -1)
55
+ if x1.dim() > 2:
56
+ x1 = x1.reshape(x1.shape[0], -1)
57
+ M = torch.cdist(x0, x1)
58
+ if power == 2:
59
+ M = M**2
60
+ ret = ot_fn(a, b, M.detach().cpu().numpy(), numItermax=1e7)
61
+ if power == 2:
62
+ ret = math.sqrt(ret)
63
+ return ret
64
+
65
+
66
+ def plot_lidar(ax, dataset, xs=None, S=25, branch_idx=None):
67
+ # Combine the dataset and trajectory points for sorting
68
+ combined_points = []
69
+ combined_colors = []
70
+ combined_sizes = []
71
+
72
+
73
+ custom_colors_1 = ["#05009E", "#A19EFF", "#50B2D7"]
74
+ custom_colors_2 = ["#05009E", "#A19EFF", "#D577FF"]
75
+
76
+ custom_cmap_1 = LinearSegmentedColormap.from_list("my_cmap", custom_colors_1)
77
+ custom_cmap_2 = LinearSegmentedColormap.from_list("my_cmap", custom_colors_2)
78
+
79
+ # Normalize the z-coordinates for alpha scaling
80
+ z_coords = (
81
+ dataset[:, 2].numpy() if torch.is_tensor(dataset[:, 2]) else dataset[:, 2]
82
+ )
83
+ z_min, z_max = z_coords.min(), z_coords.max()
84
+ z_norm = (z_coords - z_min) / (z_max - z_min)
85
+
86
+ # Add surface points with a lower z-order
87
+ for i, point in enumerate(dataset):
88
+ grey_value = 0.95 - 0.7 * z_norm[i]
89
+ combined_points.append(point.numpy())
90
+ combined_colors.append(
91
+ (
92
+ grey_value,
93
+ grey_value,
94
+ grey_value,
95
+ 1.0
96
+ )
97
+ ) # Grey color with transparency
98
+ combined_sizes.append(0.1)
99
+
100
+ # Add trajectory points with a higher z-order
101
+ if xs is not None:
102
+ if branch_idx == 0:
103
+ cmap = custom_cmap_1
104
+ else:
105
+ cmap = custom_cmap_2
106
+
107
+ B, T, D = xs.shape
108
+ steps_to_log = np.linspace(0, T - 1, S).astype(int)
109
+ xs = xs.cpu().detach().clone()
110
+ for idx, step in enumerate(steps_to_log):
111
+ for point in xs[:512, step]:
112
+ combined_points.append(
113
+ point.numpy() if torch.is_tensor(point) else point
114
+ )
115
+ combined_colors.append(cmap(idx / (len(steps_to_log) - 1)))
116
+ combined_sizes.append(0.8)
117
+
118
+ # Convert to numpy array for easier manipulation
119
+ combined_points = np.array(combined_points)
120
+ combined_colors = np.array(combined_colors)
121
+ combined_sizes = np.array(combined_sizes)
122
+
123
+ # Sort by z-coordinate (depth)
124
+ sorted_indices = np.argsort(combined_points[:, 2])
125
+ combined_points = combined_points[sorted_indices]
126
+ combined_colors = combined_colors[sorted_indices]
127
+ combined_sizes = combined_sizes[sorted_indices]
128
+
129
+ # Plot the sorted points
130
+ ax.scatter(
131
+ combined_points[:, 0],
132
+ combined_points[:, 1],
133
+ combined_points[:, 2],
134
+ s=combined_sizes,
135
+ c=combined_colors,
136
+ depthshade=True,
137
+ )
138
+
139
+ ax.set_xlim3d(left=-4.8, right=4.8)
140
+ ax.set_ylim3d(bottom=-4.8, top=4.8)
141
+ ax.set_zlim3d(bottom=0.0, top=2.0)
142
+ ax.set_zticks([0, 1.0, 2.0])
143
+ ax.grid(False)
144
+ plt.axis("off")
145
+
146
+ return ax
147
+
148
+
149
+ def plot_images_trajectory(trajectories, vae, processor, num_steps):
150
+
151
+ # Compute trajectories for each image
152
+ t_span = torch.linspace(0, trajectories.shape[1] - 1, num_steps)
153
+ t_span = [int(t) for t in t_span]
154
+ num_images = trajectories.shape[0]
155
+
156
+ # Decode images at each step in each trajectory
157
+ decoded_images = [
158
+ [
159
+ processor.postprocess(
160
+ vae.decode(
161
+ trajectories[i_image, traj_step].unsqueeze(0)
162
+ ).sample.detach()
163
+ )[0]
164
+ for traj_step in t_span
165
+ ]
166
+ for i_image in range(num_images)
167
+ ]
168
+
169
+ # Plotting
170
+ fig, axes = plt.subplots(
171
+ num_images, num_steps, figsize=(num_steps * 2, num_images * 2)
172
+ )
173
+ if num_images == 1:
174
+ axes = [axes] # Ensure axes is iterable
175
+ for img_idx, img_traj in enumerate(decoded_images):
176
+ for step_idx, img in enumerate(img_traj):
177
+ ax = axes[img_idx][step_idx] if num_images > 1 else axes[step_idx]
178
+ if (
179
+ isinstance(img, np.ndarray) and img.shape[0] == 3
180
+ ): # Assuming 3 channels (RGB)
181
+ img = img.transpose(1, 2, 0)
182
+ ax.imshow(img)
183
+ ax.axis("off")
184
+ if img_idx == 0:
185
+ ax.set_title(f"t={t_span[step_idx]/t_span[-1]:.2f}")
186
+ plt.tight_layout()
187
+ return fig
188
+
189
+
190
+ def plot_growth(dataset, growth_nets, xs, output_file='plot.pdf'):
191
+ x0s = [dataset["x0"][0]]
192
+ w0s = [dataset["x0"][1]]
193
+ x1s_list = [[dataset["x1_1"][0]], [dataset["x1_2"][0]]]
194
+ w1s_list = [[dataset["x1_1"][1]], [dataset["x1_2"][1]]]
195
+
196
+
197
+
198
+ plt.show()