sophiat44
committed on
Commit · 5a87d8d
1 Parent(s): 6501779
model upload
- branchsbm/.DS_Store +0 -0
- branchsbm/branch_flow_net_train.py +348 -0
- branchsbm/branch_growth_net_train.py +514 -0
- branchsbm/branch_interpolant_train.py +398 -0
- branchsbm/branchsbm.py +109 -0
- branchsbm/ema.py +64 -0
- configs/.DS_Store +0 -0
- configs/experiment/cell_single_branch.yaml +12 -0
- configs/experiment/clonidine_100D.yaml +22 -0
- configs/experiment/clonidine_150D.yaml +22 -0
- configs/experiment/clonidine_50D.yaml +22 -0
- configs/experiment/clonidine_50Dsingle.yaml +22 -0
- configs/experiment/lidar.yaml +14 -0
- configs/experiment/lidar_single.yaml +14 -0
- configs/experiment/mouse.yaml +17 -0
- configs/experiment/trametinib.yaml +22 -0
- configs/experiment/trametinib_single.yaml +22 -0
- dataloaders/.DS_Store +0 -0
- dataloaders/clonidine_data.py +269 -0
- dataloaders/clonidine_single_branch.py +274 -0
- dataloaders/clonidine_v2_data.py +287 -0
- dataloaders/lidar_data.py +532 -0
- dataloaders/lidar_data_single.py +282 -0
- dataloaders/mouse_data.py +438 -0
- dataloaders/three_branch_data.py +310 -0
- dataloaders/trametinib_single.py +279 -0
- losses/.DS_Store +0 -0
- losses/energy_loss.py +73 -0
- networks/.DS_Store +0 -0
- networks/flow_mlp.py +18 -0
- networks/growth_mlp.py +37 -0
- networks/interpolant_mlp.py +35 -0
- networks/mlp_base.py +46 -0
- networks/utils.py +13 -0
- state_costs/.DS_Store +0 -0
- state_costs/land.py +26 -0
- state_costs/metric_factory.py +105 -0
- state_costs/rbf.py +156 -0
- train/.DS_Store +0 -0
- train/main_branches.py +342 -0
- train/parsers.py +419 -0
- train/train_utils.py +154 -0
- utils.py +198 -0
branchsbm/.DS_Store
ADDED
Binary file (6.15 kB)
branchsbm/branch_flow_net_train.py
ADDED
@@ -0,0 +1,348 @@
import os
import sys
sys.path.append("./BranchSBM")
import torch
import wandb
import matplotlib.pyplot as plt
import pytorch_lightning as pl
from torch.optim import AdamW
from torchmetrics.functional import mean_squared_error
from torchdyn.core import NeuralODE
from networks.utils import flow_model_torch_wrapper
from utils import wasserstein_distance, plot_lidar
from branchsbm.ema import EMA


class BranchFlowNetTrainBase(pl.LightningModule):
    def __init__(
        self,
        flow_matcher,
        flow_nets,
        skipped_time_points=None,
        ot_sampler=None,
        args=None,
    ):
        super().__init__()
        self.args = args

        self.flow_matcher = flow_matcher
        self.flow_nets = flow_nets  # list of flow networks, one per branch
        self.ot_sampler = ot_sampler
        self.skipped_time_points = skipped_time_points

        self.optimizer_name = args.flow_optimizer
        self.lr = args.flow_lr
        self.weight_decay = args.flow_weight_decay
        self.whiten = args.whiten
        self.working_dir = args.working_dir

        # branching
        self.branches = len(flow_nets)

    def forward(self, t, xt, branch_idx):
        # output velocity for the given branch
        return self.flow_nets[branch_idx](t, xt)

    def _compute_loss(self, main_batch):
        x0s = [main_batch["x0"][0]]
        w0s = [main_batch["x0"][1]]

        x1s_list = []
        w1s_list = []

        if self.branches > 1:
            for i in range(self.branches):
                x1s_list.append([main_batch[f"x1_{i+1}"][0]])
                w1s_list.append([main_batch[f"x1_{i+1}"][1]])
        else:
            x1s_list.append([main_batch["x1"][0]])
            w1s_list.append([main_batch["x1"][1]])

        assert len(x1s_list) == self.branches, "Mismatch between x1s_list and expected branches"

        loss = 0
        for branch_idx in range(self.branches):
            ts, xts, uts = self._process_flow(x0s, x1s_list[branch_idx], branch_idx)

            t = torch.cat(ts)
            xt = torch.cat(xts)
            ut = torch.cat(uts)
            vt = self(t[:, None], xt, branch_idx)

            loss += mean_squared_error(vt, ut)

        return loss

    def _process_flow(self, x0s, x1s, branch_idx):
        ts, xts, uts = [], [], []
        t_start = self.timesteps[0]

        for i, (x0, x1) in enumerate(zip(x0s, x1s)):
            x0, x1 = torch.squeeze(x0), torch.squeeze(x1)

            if self.ot_sampler is not None:
                x0, x1 = self.ot_sampler.sample_plan(
                    x0,
                    x1,
                    replace=True,
                )
            if self.skipped_time_points and i + 1 >= self.skipped_time_points[0]:
                t_start_next = self.timesteps[i + 2]
            else:
                t_start_next = self.timesteps[i + 1]

            # sample (t, x_t, u_t) from the flow matcher for this branch and interval
            t, xt, ut = self.flow_matcher.sample_location_and_conditional_flow(
                x0, x1, t_start, t_start_next, branch_idx
            )

            ts.append(t)
            xts.append(xt)
            uts.append(ut)
            t_start = t_start_next
        return ts, xts, uts

    def training_step(self, batch, batch_idx):
        if self.args.data_type in ["scrna", "tahoe"]:
            main_batch = batch[0]["train_samples"][0]
        else:
            main_batch = batch["train_samples"][0]

        self.timesteps = torch.linspace(0.0, 1.0, len(main_batch["x0"])).tolist()
        loss = self._compute_loss(main_batch)
        if self.flow_matcher.alpha != 0:
            self.log(
                "FlowNet/mean_geopath_cfm",
                self.flow_matcher.geopath_net_output.abs().mean(),
                on_step=False,
                on_epoch=True,
                prog_bar=True,
            )

        self.log(
            "FlowNet/train_loss_cfm",
            loss,
            on_step=False,
            on_epoch=True,
            prog_bar=True,
            logger=True,
        )

        return loss

    def validation_step(self, batch, batch_idx):
        if self.args.data_type in ["scrna", "tahoe"]:
            main_batch = batch[0]["val_samples"][0]
        else:
            main_batch = batch["val_samples"][0]

        self.timesteps = torch.linspace(0.0, 1.0, len(main_batch["x0"])).tolist()
        val_loss = self._compute_loss(main_batch)
        self.log(
            "FlowNet/val_loss_cfm",
            val_loss,
            on_step=False,
            on_epoch=True,
            prog_bar=True,
            logger=True,
        )
        return val_loss

    def optimizer_step(self, *args, **kwargs):
        super().optimizer_step(*args, **kwargs)

        for net in self.flow_nets:
            if isinstance(net, EMA):
                net.update_ema()

    def configure_optimizers(self):
        if self.optimizer_name == "adamw":
            optimizer = AdamW(
                self.parameters(),
                lr=self.lr,
                weight_decay=self.weight_decay,
            )
        elif self.optimizer_name == "adam":
            optimizer = torch.optim.Adam(
                self.parameters(),
                lr=self.lr,
            )

        return optimizer


class FlowNetTrainTrajectory(BranchFlowNetTrainBase):
    def test_step(self, batch, batch_idx):
        node = NeuralODE(
            flow_model_torch_wrapper(self.flow_nets),
            solver="euler",
            sensitivity="adjoint",
            atol=1e-5,
            rtol=1e-5,
        )

        t_exclude = self.skipped_time_points[0] if self.skipped_time_points else None
        if t_exclude is not None:
            # integrate up to the held-out time point and compare distributions there
            traj = node.trajectory(
                batch[t_exclude - 1],
                t_span=torch.linspace(
                    self.timesteps[t_exclude - 1], self.timesteps[t_exclude], 101
                ),
            )
            X_mid_pred = traj[-1]
            traj = node.trajectory(
                batch[t_exclude - 1],
                t_span=torch.linspace(
                    self.timesteps[t_exclude - 1],
                    self.timesteps[t_exclude + 1],
                    101,
                ),
            )

            EMD = wasserstein_distance(X_mid_pred, batch[t_exclude], p=1)
            self.final_EMD = EMD

            self.log("test_EMD", EMD, on_step=False, on_epoch=True, prog_bar=True)


class FlowNetTrainCell(BranchFlowNetTrainBase):
    def test_step(self, batch, batch_idx):
        x0 = batch[0]["test_samples"][0]["x0"][0]  # [B, D]
        dataset_points = batch[0]["test_samples"][0]["dataset"][0]  # full dataset, [N, D]
        t_span = torch.linspace(0, 1, 101)

        # Inverse-transform the dataset once (not per branch); after this it is a numpy array
        if self.whiten:
            dataset_points = self.trainer.datamodule.scaler.inverse_transform(
                dataset_points.cpu().detach().numpy()
            )

        all_trajs = []

        for i, flow_net in enumerate(self.flow_nets):
            node = NeuralODE(
                flow_model_torch_wrapper(flow_net),
                solver="euler",
                sensitivity="adjoint",
            )

            with torch.no_grad():
                traj = node.trajectory(x0, t_span).cpu()  # [T, B, D]

            if self.whiten:
                traj_shape = traj.shape
                traj = traj.reshape(-1, traj.shape[-1])
                traj = self.trainer.datamodule.scaler.inverse_transform(
                    traj.cpu().detach().numpy()
                ).reshape(traj_shape)

            traj = torch.tensor(traj)
            traj = torch.transpose(traj, 0, 1)  # [B, T, D]
            all_trajs.append(traj)

        # dataset_points is numpy after inverse_transform, a tensor otherwise
        if isinstance(dataset_points, torch.Tensor):
            dataset_points = dataset_points.cpu().numpy()
        dataset_2d = dataset_points[:, :2]

        # ===== Plot all 2D trajectories together with dataset and start/end points =====
        fig, ax = plt.subplots(figsize=(6, 5))
        ax.scatter(dataset_2d[:, 0], dataset_2d[:, 1], c="gray", s=1, alpha=0.5, label="Dataset", zorder=1)
        for traj in all_trajs:
            traj_2d = traj[..., :2]  # [B, T, 2]
            for i in range(traj_2d.shape[0]):
                ax.plot(traj_2d[i, :, 0], traj_2d[i, :, 1], alpha=0.8, zorder=2)
                ax.scatter(traj_2d[i, 0, 0], traj_2d[i, 0, 1], c='green', s=10, label="t=0" if i == 0 else "", zorder=3)
                ax.scatter(traj_2d[i, -1, 0], traj_2d[i, -1, 1], c='red', s=10, label="t=1" if i == 0 else "", zorder=3)

        ax.set_title("All Branch Trajectories (2D) with Dataset")
        ax.set_xlabel("x")
        ax.set_ylabel("y")
        plt.axis("equal")
        handles, labels = ax.get_legend_handles_labels()
        if labels:
            ax.legend()

        save_path = f'./figures/{self.args.data_name}'
        os.makedirs(save_path, exist_ok=True)
        plt.savefig(f'{save_path}/{self.args.data_name}_all_branches.png', dpi=300)
        plt.close()

        # ===== Plot each 2D trajectory separately with dataset and endpoints =====
        for i, traj in enumerate(all_trajs):
            traj_2d = traj[..., :2]
            fig, ax = plt.subplots(figsize=(6, 5))
            ax.scatter(dataset_2d[:, 0], dataset_2d[:, 1], c="gray", s=1, alpha=0.5, label="Dataset", zorder=1)
            for j in range(traj_2d.shape[0]):
                ax.plot(traj_2d[j, :, 0], traj_2d[j, :, 1], alpha=0.9, zorder=2)
                ax.scatter(traj_2d[j, 0, 0], traj_2d[j, 0, 1], c='green', s=12, label="t=0" if j == 0 else "", zorder=3)
                ax.scatter(traj_2d[j, -1, 0], traj_2d[j, -1, 1], c='red', s=12, label="t=1" if j == 0 else "", zorder=3)

            ax.set_title(f"Branch {i + 1} Trajectories (2D) with Dataset")
            ax.set_xlabel("x")
            ax.set_ylabel("y")
            plt.axis("equal")
            handles, labels = ax.get_legend_handles_labels()
            if labels:
                ax.legend()
            plt.savefig(f'{save_path}/{self.args.data_name}_branch_{i + 1}.png', dpi=300)
            plt.close()


class FlowNetTrainLidar(BranchFlowNetTrainBase):
    def test_step(self, batch, batch_idx):
        main_batch = batch["test_samples"][0]
        metric_batch = batch["metric_samples"][0]

        x0 = main_batch["x0"][0]  # [B, D]
        cloud_points = main_batch["dataset"][0]  # full dataset, [N, D]
        t_span = torch.linspace(0, 1, 101)

        all_trajs = []

        for i, flow_net in enumerate(self.flow_nets):
            node = NeuralODE(
                flow_model_torch_wrapper(flow_net),
                solver="euler",
                sensitivity="adjoint",
            )

            with torch.no_grad():
                traj = node.trajectory(x0, t_span).cpu()  # [T, B, D]

            if self.whiten:
                traj_shape = traj.shape
                traj = traj.reshape(-1, 3)
                traj = self.trainer.datamodule.scaler.inverse_transform(
                    traj.cpu().detach().numpy()
                ).reshape(traj_shape)

            traj = torch.tensor(traj)
            traj = torch.transpose(traj, 0, 1)  # [B, T, D]
            all_trajs.append(traj)

        # Inverse-transform the point cloud once
        if self.whiten:
            cloud_points = torch.tensor(
                self.trainer.datamodule.scaler.inverse_transform(
                    cloud_points.cpu().detach().numpy()
                )
            )

        # ===== Plot all trajectories together =====
        fig = plt.figure(figsize=(6, 5))
        ax = fig.add_subplot(111, projection="3d", computed_zorder=False)
        ax.view_init(elev=30, azim=-115, roll=0)
        for i, traj in enumerate(all_trajs):
            plot_lidar(ax, cloud_points, xs=traj, branch_idx=i)
        plt.savefig('./figures/lidar/lidar_all_branches.png', dpi=300)
        plt.close()

        # ===== Plot each trajectory separately =====
        for i, traj in enumerate(all_trajs):
            fig = plt.figure(figsize=(6, 5))
            ax = fig.add_subplot(111, projection="3d", computed_zorder=False)
            ax.view_init(elev=30, azim=-115, roll=0)
            plot_lidar(ax, cloud_points, xs=traj, branch_idx=i)
            plt.savefig(f'./figures/lidar/lidar_branch_{i + 1}.png', dpi=300)
            plt.close()
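
For readers skimming the diff: all three Lightning modules above train the same per-branch conditional flow matching objective. For each branch k, sample (t, x_t, u_t) from the matcher and regress the branch velocity net v_k(t, x_t) onto u_t with an MSE, summing over branches. The following is a minimal, self-contained sketch of that objective (not from this commit): it uses plain torchcfm and toy two-layer MLPs as stand-ins for the repo's networks/flow_mlp.py models and its batch pipeline.

import torch
import torch.nn as nn
from torchcfm.conditional_flow_matching import ConditionalFlowMatcher

class TinyFlowNet(nn.Module):
    """Stand-in velocity net v(t, x); not the repo's flow_mlp model."""
    def __init__(self, dim):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(dim + 1, 64), nn.SiLU(), nn.Linear(64, dim))

    def forward(self, t, x):
        return self.net(torch.cat([t, x], dim=-1))

dim, batch = 2, 128
flow_nets = nn.ModuleList([TinyFlowNet(dim) for _ in range(2)])  # one net per branch
fm = ConditionalFlowMatcher(sigma=0.1)

x0 = torch.randn(batch, dim)                                      # shared source population
x1s = [torch.randn(batch, dim) + 3, torch.randn(batch, dim) - 3]  # one target per branch

loss = 0.0
for k, x1 in enumerate(x1s):
    t, xt, ut = fm.sample_location_and_conditional_flow(x0, x1)
    vt = flow_nets[k](t[:, None], xt)
    loss = loss + ((vt - ut) ** 2).mean()  # summed per-branch MSE, as in _compute_loss
loss.backward()

The repo's version differs in that sampling is segmented over consecutive timepoint pairs and routed through its BranchSBM matcher (with optional OT coupling), but the regression target per branch is the same.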
branchsbm/branch_growth_net_train.py
ADDED
@@ -0,0 +1,514 @@
import os
import sys
sys.path.append("./BranchSBM")
import torch
import wandb
import matplotlib.pyplot as plt
import pytorch_lightning as pl
from torch.optim import AdamW
from torchmetrics.functional import mean_squared_error
from torchdyn.core import NeuralODE
import numpy as np
import lpips
from networks.utils import flow_model_torch_wrapper
from utils import wasserstein_distance, plot_lidar
from branchsbm.ema import EMA
from torchdiffeq import odeint as odeint2
from losses.energy_loss import EnergySolver, ReconsLoss


class GrowthNetTrain(pl.LightningModule):
    def __init__(
        self,
        flow_nets,
        growth_nets,
        skipped_time_points=None,
        ot_sampler=None,
        args=None,
        state_cost=None,
        data_manifold_metric=None,
        joint=False,
    ):
        super().__init__()
        # self.save_hyperparameters()
        self.flow_nets = flow_nets

        # freeze the flow networks unless training jointly
        if not joint:
            for param in self.flow_nets.parameters():
                param.requires_grad = False

        self.growth_nets = growth_nets  # list of growth networks, one per branch

        self.ot_sampler = ot_sampler
        self.skipped_time_points = skipped_time_points

        self.optimizer_name = args.growth_optimizer
        self.lr = args.growth_lr
        self.weight_decay = args.growth_weight_decay
        self.whiten = args.whiten
        self.working_dir = args.working_dir

        self.args = args

        # branching
        self.state_cost = state_cost
        self.data_manifold_metric = data_manifold_metric
        self.branches = len(growth_nets)
        self.metric_clusters = args.metric_clusters

        self.recons_loss = ReconsLoss()

        # loss weights
        self.lambda_energy = args.lambda_energy
        self.lambda_mass = args.lambda_mass
        self.lambda_match = args.lambda_match
        self.lambda_recons = args.lambda_recons

        self.joint = joint

    def forward(self, t, xt, branch_idx):
        # output growth rate for the given branch
        return self.growth_nets[branch_idx](t, xt)

    def _compute_loss(self, main_batch, metric_samples_batch=None, validation=False):
        x0s = main_batch["x0"][0]
        w0s = main_batch["x0"][1]
        x1s_list = []
        w1s_list = []

        if self.branches > 1:
            for i in range(self.branches):
                x1s_list.append([main_batch[f"x1_{i+1}"][0]])
                w1s_list.append([main_batch[f"x1_{i+1}"][1]])
        else:
            x1s_list.append([main_batch["x1"][0]])
            w1s_list.append([main_batch["x1"][1]])

        if self.args.manifold:
            if self.metric_clusters == 4:
                branch_sample_pairs = [
                    (metric_samples_batch[0], metric_samples_batch[1]),  # x0 → x1_1 (branch 1)
                    (metric_samples_batch[0], metric_samples_batch[2]),  # x0 → x1_2 (branch 2)
                    (metric_samples_batch[0], metric_samples_batch[3]),  # x0 → x1_3 (branch 3)
                ]
            elif self.metric_clusters == 3:
                branch_sample_pairs = [
                    (metric_samples_batch[0], metric_samples_batch[1]),  # x0 → x1_1 (branch 1)
                    (metric_samples_batch[0], metric_samples_batch[2]),  # x0 → x1_2 (branch 2)
                ]
            elif self.metric_clusters == 2 and self.branches == 2:
                branch_sample_pairs = [
                    (metric_samples_batch[0], metric_samples_batch[1]),  # x0 → x1_1 (branch 1)
                    (metric_samples_batch[0], metric_samples_batch[1]),  # x0 → x1_2 (branch 2)
                ]
            else:
                branch_sample_pairs = [
                    (metric_samples_batch[0], metric_samples_batch[1]),  # x0 → x1_1 (branch 1)
                ]

        assert len(x1s_list) == self.branches, "Mismatch between x1s_list and expected branches"

        energy_loss = [0.] * self.branches
        mass_loss = 0.
        neg_weight_penalty = 0.
        match_loss = [0.] * self.branches
        recons_loss = [0.] * self.branches

        dtype = x0s[0].dtype
        m0s = torch.zeros_like(w0s, dtype=dtype)

        xt = [x0s.clone() for _ in range(self.branches)]
        # the first branch carries the full initial mass; the others start empty
        w0_branch = torch.zeros_like(w0s, dtype=dtype)
        w0_branches = [w0s]
        for _ in range(self.branches - 1):
            w0_branches.append(w0_branch)
        wt = w0_branches

        mt = [m0s.clone() for _ in range(self.branches)]

        # loop through timesteps
        for s, t in zip(self.timesteps[:-1], self.timesteps[1:]):
            time = torch.Tensor([s, t])

            total_w_t = 0
            # loop through branches
            for i in range(self.branches):
                if self.args.manifold:
                    start_samples, end_samples = branch_sample_pairs[i]
                    samples = torch.cat([start_samples, end_samples], dim=0)
                else:
                    samples = None  # Euclidean state cost needs no metric samples

                # current position, weight, and accumulated energy for this branch
                start_state = (xt[i], wt[i], mt[i])

                # advance one interval of the augmented ODE
                xt_next, wt_next, mt_next = self.take_step(time, start_state, i, samples)

                xt_last = xt_next[-1]
                wt_last = wt_next[-1]
                mt_last = mt_next[-1]

                total_w_t += wt_last

                energy_loss[i] += (mt_last - mt[i])
                neg_weight_penalty += torch.relu(-wt_last).sum()

                # update branch state
                xt[i] = xt_last.clone().detach()
                wt[i] = wt_last.clone().detach()
                mt[i] = mt_last.clone().detach()

            # total mass across branches should stay conserved at 1
            target = torch.ones_like(total_w_t)
            mass_loss += mean_squared_error(total_w_t, target)

        # loss matching the final weights, plus reconstruction of the final distribution
        for i in range(self.branches):
            match_loss[i] = mean_squared_error(wt[i], w1s_list[i][0])
            recons_loss[i] = self.recons_loss(xt[i], x1s_list[i][0])

        # average across times
        mass_loss = mass_loss / len(self.timesteps)

        # mean across branches
        energy_loss = torch.mean(torch.stack(energy_loss))
        match_loss = torch.mean(torch.stack(match_loss))
        recons_loss = torch.mean(torch.stack(recons_loss))

        loss = (self.lambda_energy * energy_loss) \
            + (self.lambda_mass * (mass_loss + neg_weight_penalty)) \
            + (self.lambda_match * match_loss) \
            + (self.lambda_recons * recons_loss)

        prefix = "JointTrain" if self.joint else "GrowthNet"
        split = "val" if validation else "train"
        self.log(f"{prefix}/{split}_mass_loss", mass_loss, on_step=False, on_epoch=True, prog_bar=True)
        self.log(f"{prefix}/{split}_neg_penalty_loss", neg_weight_penalty, on_step=False, on_epoch=True, prog_bar=True)
        self.log(f"{prefix}/{split}_match_loss", match_loss, on_step=False, on_epoch=True, prog_bar=True)
        self.log(f"{prefix}/{split}_energy_loss", energy_loss, on_step=False, on_epoch=True, prog_bar=True)
        self.log(f"{prefix}/{split}_recons_loss", recons_loss, on_step=False, on_epoch=True, prog_bar=True)

        return loss

    def take_step(self, t, start_state, branch_idx, samples=None):
        flow_net = self.flow_nets[branch_idx]
        growth_net = self.growth_nets[branch_idx]

        # jointly integrate position, weight, and accumulated energy
        x_t, w_t, m_t = odeint2(
            EnergySolver(flow_net, growth_net, self.state_cost, self.data_manifold_metric, samples),
            start_state,
            t,
            options=dict(step_size=0.1),
            method='euler',
        )

        return x_t, w_t, m_t

    def training_step(self, batch, batch_idx):
        if self.args.data_type in ["scrna", "tahoe"]:
            main_batch = batch[0]["train_samples"][0]
            metric_batch = batch[0]["metric_samples"][0]
        else:
            main_batch = batch["train_samples"][0]
            metric_batch = batch["metric_samples"][0]

        self.timesteps = torch.linspace(0.0, 1.0, len(main_batch["x0"])).tolist()
        loss = self._compute_loss(main_batch, metric_batch, validation=False)

        name = "JointTrain/train_loss" if self.joint else "GrowthNet/train_loss"
        self.log(
            name,
            loss,
            on_step=False,
            on_epoch=True,
            prog_bar=True,
            logger=True,
        )

        return loss

    def validation_step(self, batch, batch_idx):
        if self.args.data_type in ["scrna", "tahoe"]:
            main_batch = batch[0]["val_samples"][0]
            metric_batch = batch[0]["metric_samples"][0]
        else:
            main_batch = batch["val_samples"][0]
            metric_batch = batch["metric_samples"][0]

        self.timesteps = torch.linspace(0.0, 1.0, len(main_batch["x0"])).tolist()
        val_loss = self._compute_loss(main_batch, metric_batch, validation=True)

        name = "JointTrain/val_loss" if self.joint else "GrowthNet/val_loss"
        self.log(
            name,
            val_loss,
            on_step=False,
            on_epoch=True,
            prog_bar=True,
            logger=True,
        )
        return val_loss

    def optimizer_step(self, *args, **kwargs):
        super().optimizer_step(*args, **kwargs)
        for net in self.growth_nets:
            if isinstance(net, EMA):
                net.update_ema()
        if self.joint:
            for net in self.flow_nets:
                if isinstance(net, EMA):
                    net.update_ema()

    def configure_optimizers(self):
        params = []
        for net in self.growth_nets:
            params += list(net.parameters())

        if self.joint:
            for net in self.flow_nets:
                params += list(net.parameters())

        if self.optimizer_name == "adamw":
            optimizer = AdamW(
                params,
                lr=self.lr,
                weight_decay=self.weight_decay,
            )
        elif self.optimizer_name == "adam":
            optimizer = torch.optim.Adam(
                params,
                lr=self.lr,
            )

        return optimizer

    @torch.no_grad()
    def _plot_mass_and_energy(self, main_batch, metric_samples_batch=None, save_dir="./figures"):
        x0s = main_batch["x0"][0]
        w0s = main_batch["x0"][1]

        if self.args.manifold:
            if self.metric_clusters == 4:
                branch_sample_pairs = [
                    (metric_samples_batch[0], metric_samples_batch[1]),
                    (metric_samples_batch[0], metric_samples_batch[2]),
                    (metric_samples_batch[0], metric_samples_batch[3]),
                ]
            elif self.metric_clusters == 3:
                branch_sample_pairs = [
                    (metric_samples_batch[0], metric_samples_batch[1]),
                    (metric_samples_batch[0], metric_samples_batch[2]),
                ]
            elif self.metric_clusters == 2 and self.branches == 2:
                branch_sample_pairs = [
                    (metric_samples_batch[0], metric_samples_batch[1]),
                    (metric_samples_batch[0], metric_samples_batch[1]),
                ]
            else:
                branch_sample_pairs = [
                    (metric_samples_batch[0], metric_samples_batch[1]),
                ]

        dtype = x0s[0].dtype

        m0s = torch.zeros_like(w0s, dtype=dtype)
        xt = [x0s.clone() for _ in range(self.branches)]

        w0_branch = torch.zeros_like(w0s, dtype=dtype)
        w0_branches = [w0s]
        for _ in range(self.branches - 1):
            w0_branches.append(w0_branch)

        wt = w0_branches
        mt = [m0s.clone() for _ in range(self.branches)]

        time_points = []
        mass_over_time = [[] for _ in range(self.branches)]
        energy_over_time = [[] for _ in range(self.branches)]

        t_span = torch.linspace(0, 1, 101)
        for s, t in zip(t_span[:-1], t_span[1:]):
            time_points.append(t.item())
            time = torch.Tensor([s, t])

            for i in range(self.branches):
                if self.args.manifold:
                    start_samples, end_samples = branch_sample_pairs[i]
                    samples = torch.cat([start_samples, end_samples], dim=0)
                else:
                    samples = None

                start_state = (xt[i], wt[i], mt[i])
                xt_next, wt_next, mt_next = self.take_step(time, start_state, i, samples)

                xt[i] = xt_next[-1].clone().detach()
                wt[i] = wt_next[-1].clone().detach()
                mt[i] = mt_next[-1].clone().detach()

                mass_over_time[i].append(wt[i].mean().item())
                energy_over_time[i].append(mt[i].mean().item())

        os.makedirs(os.path.join(save_dir, self.args.data_type), exist_ok=True)

        # hand-picked, visually distinct hex color per branch
        if self.args.branches == 3:
            branch_colors = ['#9793F8', '#50B2D7', '#D577FF']
        else:
            branch_colors = ['#50B2D7', '#D577FF']

        # --- Plot Mass ---
        plt.figure(figsize=(8, 5))
        for i in range(self.branches):
            color = branch_colors[i]
            plt.plot(time_points, mass_over_time[i], color=color, linewidth=2.5, label=f"Mass Branch {i}")
        plt.xlabel("Time")
        plt.ylabel("Mass")
        plt.title("Mass Evolution per Branch")
        plt.legend()
        plt.grid(True)
        if self.joint:
            mass_path = os.path.join(save_dir, f"{self.args.data_name}/{self.args.data_name}_joint_mass.png")
        else:
            mass_path = os.path.join(save_dir, f"{self.args.data_name}/{self.args.data_name}_growth_mass.png")
        plt.savefig(mass_path, dpi=300, bbox_inches="tight")
        plt.close()

        # --- Plot Energy ---
        plt.figure(figsize=(8, 5))
        for i in range(self.branches):
            color = branch_colors[i]
            plt.plot(time_points, energy_over_time[i], color=color, linewidth=2.5, label=f"Energy Branch {i}")
        plt.xlabel("Time")
        plt.ylabel("Energy")
        plt.title("Energy Evolution per Branch")
        plt.legend()
        plt.grid(True)
        if self.joint:
            energy_path = os.path.join(save_dir, f"{self.args.data_name}/{self.args.data_name}_joint_energy.png")
        else:
            energy_path = os.path.join(save_dir, f"{self.args.data_name}/{self.args.data_name}_growth_energy.png")
        plt.savefig(energy_path, dpi=300, bbox_inches="tight")
        plt.close()


class GrowthNetTrainLidar(GrowthNetTrain):
    def test_step(self, batch, batch_idx):
        main_batch = batch["test_samples"][0]
        metric_batch = batch["metric_samples"][0]

        self._plot_mass_and_energy(main_batch, metric_batch)

        x0 = main_batch["x0"][0]  # [B, D]
        cloud_points = main_batch["dataset"][0]  # full dataset, [N, D]
        t_span = torch.linspace(0, 1, 101)

        all_trajs = []

        for i, flow_net in enumerate(self.flow_nets):
            node = NeuralODE(
                flow_model_torch_wrapper(flow_net),
                solver="euler",
                sensitivity="adjoint",
            )

            with torch.no_grad():
                traj = node.trajectory(x0, t_span).cpu()  # [T, B, D]

            if self.whiten:
                traj_shape = traj.shape
                traj = traj.reshape(-1, 3)
                traj = self.trainer.datamodule.scaler.inverse_transform(
                    traj.cpu().detach().numpy()
                ).reshape(traj_shape)

            traj = torch.tensor(traj)
            traj = torch.transpose(traj, 0, 1)  # [B, T, D]
            all_trajs.append(traj)

        # Inverse-transform the point cloud once
        if self.whiten:
            cloud_points = torch.tensor(
                self.trainer.datamodule.scaler.inverse_transform(
                    cloud_points.cpu().detach().numpy()
                )
            )

        # ===== Plot all trajectories together =====
        fig = plt.figure(figsize=(6, 5))
        ax = fig.add_subplot(111, projection="3d", computed_zorder=False)
        ax.view_init(elev=30, azim=-115, roll=0)
        for i, traj in enumerate(all_trajs):
            plot_lidar(ax, cloud_points, xs=traj, branch_idx=i)
        if self.joint:
            plt.savefig('./figures/lidar/joint_lidar_all_branches.png', dpi=300)
        else:
            plt.savefig('./figures/lidar/growth_lidar_all_branches.png', dpi=300)
        plt.close()

        # ===== Plot each trajectory separately =====
        for i, traj in enumerate(all_trajs):
            fig = plt.figure(figsize=(6, 5))
            ax = fig.add_subplot(111, projection="3d", computed_zorder=False)
            ax.view_init(elev=30, azim=-115, roll=0)
            plot_lidar(ax, cloud_points, xs=traj, branch_idx=i)
            if self.joint:
                plt.savefig(f'./figures/lidar/joint_lidar_branch_{i + 1}.png', dpi=300)
            else:
                plt.savefig(f'./figures/lidar/growth_lidar_branch_{i + 1}.png', dpi=300)
            plt.close()


class GrowthNetTrainCell(GrowthNetTrain):
    def test_step(self, batch, batch_idx):
        if self.args.data_type in ["scrna", "tahoe"]:
            main_batch = batch[0]["test_samples"][0]
            metric_batch = batch[0]["metric_samples"][0]
        else:
            main_batch = batch["test_samples"][0]
            metric_batch = batch["metric_samples"][0]

        self._plot_mass_and_energy(main_batch, metric_batch)
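
The core move in GrowthNetTrain.take_step is integrating an augmented state per branch, position x, branch weight w, and accumulated energy m, in a single torchdiffeq call with a tuple state. EnergySolver's actual integrand lives in losses/energy_loss.py (added in this commit but not reproduced here), so the sketch below uses a hypothetical ToySolver with an assumed kinetic-energy integrand purely to illustrate the tuple-state mechanics:

import torch
import torch.nn as nn
from torchdiffeq import odeint

class TinyNet(nn.Module):
    """Stand-in for the flow/growth MLPs; takes (t, x) like the repo's nets."""
    def __init__(self, dim, out_dim):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(dim + 1, 64), nn.SiLU(), nn.Linear(64, out_dim))

    def forward(self, t, x):
        return self.net(torch.cat([t, x], dim=-1))

class ToySolver(nn.Module):
    """Hypothetical stand-in for losses/energy_loss.EnergySolver."""
    def __init__(self, flow_net, growth_net):
        super().__init__()
        self.flow_net, self.growth_net = flow_net, growth_net

    def forward(self, t, state):
        x, w, m = state
        tt = t * torch.ones(x.shape[0], 1)       # broadcast scalar time per sample
        v = self.flow_net(tt, x)                 # dx/dt: branch velocity field
        g = self.growth_net(tt, x)               # per-state growth rate
        dm = w * (v ** 2).sum(-1, keepdim=True)  # assumed kinetic-energy integrand
        return v, g * w, dm                      # (dx/dt, dw/dt, dm/dt)

dim, batch = 2, 16
flow_net, growth_net = TinyNet(dim, dim), TinyNet(dim, 1)
x0 = torch.randn(batch, dim)
w0 = torch.ones(batch, 1)    # unit starting mass for this toy branch
m0 = torch.zeros(batch, 1)   # accumulated energy starts at zero
xt, wt, mt = odeint(
    ToySolver(flow_net, growth_net), (x0, w0, m0),
    torch.tensor([0.0, 1.0]), method="euler", options=dict(step_size=0.1),
)
# xt[-1], wt[-1], mt[-1] mirror the xt_next[-1], wt_next[-1], mt_next[-1]
# endpoints that _compute_loss reads after each take_step call.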
branchsbm/branch_interpolant_train.py
ADDED
@@ -0,0 +1,398 @@
| 1 |
+
import sys
|
| 2 |
+
sys.path.append("./BranchSBM")
|
| 3 |
+
import torch
|
| 4 |
+
import pytorch_lightning as pl
|
| 5 |
+
from branchsbm.ema import EMA
|
| 6 |
+
import itertools
|
| 7 |
+
from utils import wasserstein_distance, plot_lidar
|
| 8 |
+
import matplotlib.pyplot as plt
|
| 9 |
+
|
| 10 |
+
class BranchInterpolantTrain(pl.LightningModule):
|
| 11 |
+
def __init__(
|
| 12 |
+
self,
|
| 13 |
+
flow_matcher,
|
| 14 |
+
args,
|
| 15 |
+
skipped_time_points: list = None,
|
| 16 |
+
ot_sampler=None,
|
| 17 |
+
|
| 18 |
+
state_cost=None,
|
| 19 |
+
data_manifold_metric=None,
|
| 20 |
+
):
|
| 21 |
+
super().__init__()
|
| 22 |
+
self.save_hyperparameters()
|
| 23 |
+
self.args = args
|
| 24 |
+
|
| 25 |
+
self.flow_matcher = flow_matcher
|
| 26 |
+
|
| 27 |
+
# list of geopath nets
|
| 28 |
+
self.geopath_nets = flow_matcher.geopath_nets
|
| 29 |
+
self.branches = len(self.geopath_nets)
|
| 30 |
+
self.metric_clusters = args.metric_clusters
|
| 31 |
+
|
| 32 |
+
self.ot_sampler = ot_sampler
|
| 33 |
+
self.skipped_time_points = skipped_time_points if skipped_time_points else []
|
| 34 |
+
self.optimizer_name = args.geopath_optimizer
|
| 35 |
+
self.lr = args.geopath_lr
|
| 36 |
+
self.weight_decay = args.geopath_weight_decay
|
| 37 |
+
self.args = args
|
| 38 |
+
self.multiply_validation = 4
|
| 39 |
+
|
| 40 |
+
self.first_loss = None
|
| 41 |
+
self.timesteps = None
|
| 42 |
+
self.computing_reference_loss = False
|
| 43 |
+
|
| 44 |
+
# updates
|
| 45 |
+
self.state_cost = state_cost
|
| 46 |
+
self.data_manifold_metric = data_manifold_metric
|
| 47 |
+
self.whiten = args.whiten
|
| 48 |
+
|
| 49 |
+
def forward(self, x0, x1, t, branch_idx):
|
| 50 |
+
# return specific branch interpolant
|
| 51 |
+
return self.geopath_nets[branch_idx](x0, x1, t)
|
| 52 |
+
|
| 53 |
+
def on_train_start(self):
|
| 54 |
+
self.first_loss = self.compute_initial_loss()
|
| 55 |
+
print("first loss")
|
| 56 |
+
print(self.first_loss)
|
| 57 |
+
|
| 58 |
+
# to edit
|
| 59 |
+
def compute_initial_loss(self):
|
| 60 |
+
# Set all GeoPath networks to eval mode
|
| 61 |
+
for net in self.geopath_nets:
|
| 62 |
+
net.train(mode=False)
|
| 63 |
+
|
| 64 |
+
total_loss = 0
|
| 65 |
+
total_count = 0
|
| 66 |
+
with torch.enable_grad():
|
| 67 |
+
self.t_val = []
|
| 68 |
+
for i in range(
|
| 69 |
+
self.trainer.datamodule.num_timesteps - len(self.skipped_time_points)
|
| 70 |
+
):
|
| 71 |
+
self.t_val.append(
|
| 72 |
+
torch.rand(
|
| 73 |
+
self.trainer.datamodule.batch_size * self.multiply_validation,
|
| 74 |
+
requires_grad=True,
|
| 75 |
+
)
|
| 76 |
+
)
|
| 77 |
+
self.computing_reference_loss = True
|
| 78 |
+
with torch.no_grad():
|
| 79 |
+
old_alpha = self.flow_matcher.alpha
|
| 80 |
+
self.flow_matcher.alpha = 0
|
| 81 |
+
for batch in self.trainer.datamodule.train_dataloader():
|
| 82 |
+
|
| 83 |
+
self.timesteps = torch.linspace(
|
| 84 |
+
0.0, 1.0, len(batch[0]["train_samples"][0])
|
| 85 |
+
)
|
| 86 |
+
|
| 87 |
+
loss = self._compute_loss(
|
| 88 |
+
batch[0]["train_samples"][0],
|
| 89 |
+
batch[0]["metric_samples"][0],
|
| 90 |
+
)
|
| 91 |
+
print("initial loss")
|
| 92 |
+
print(loss)
|
| 93 |
+
total_loss += loss.item()
|
| 94 |
+
total_count += 1
|
| 95 |
+
self.flow_matcher.alpha = old_alpha
|
| 96 |
+
|
| 97 |
+
self.computing_reference_loss = False
|
| 98 |
+
|
| 99 |
+
# Set all GeoPath networks back to training mode
|
| 100 |
+
for net in self.geopath_nets:
|
| 101 |
+
net.train(mode=True)
|
| 102 |
+
return total_loss / total_count if total_count > 0 else 1.0
|
| 103 |
+
|
| 104 |
+
def _compute_loss(self, main_batch, metric_samples_batch=None):
|
| 105 |
+
|
| 106 |
+
x0s = [main_batch["x0"][0]]
|
| 107 |
+
w0s = [main_batch["x0"][1]]
|
| 108 |
+
|
| 109 |
+
x1s_list = []
|
| 110 |
+
w1s_list = []
|
| 111 |
+
|
| 112 |
+
if self.branches > 1:
|
| 113 |
+
for i in range(self.branches):
|
| 114 |
+
x1s_list.append([main_batch[f"x1_{i+1}"][0]])
|
| 115 |
+
w1s_list.append([main_batch[f"x1_{i+1}"][1]])
|
| 116 |
+
else:
|
| 117 |
+
x1s_list.append([main_batch["x1"][0]])
|
| 118 |
+
w1s_list.append([main_batch["x1"][1]])
|
| 119 |
+
|
| 120 |
+
if self.args.manifold:
|
| 121 |
+
#changed
|
| 122 |
+
if self.metric_clusters == 4:
|
| 123 |
+
branch_sample_pairs = [
|
| 124 |
+
(metric_samples_batch[0], metric_samples_batch[1]), # x0 → x1_1 (branch 1)
|
| 125 |
+
(metric_samples_batch[0], metric_samples_batch[2]), # x0 → x1_2 (branch 2)
|
| 126 |
+
(metric_samples_batch[0], metric_samples_batch[3]),
|
| 127 |
+
]
|
| 128 |
+
elif self.metric_clusters == 3:
|
| 129 |
+
branch_sample_pairs = [
|
| 130 |
+
(metric_samples_batch[0], metric_samples_batch[1]), # x0 → x1_1 (branch 1)
|
| 131 |
+
(metric_samples_batch[0], metric_samples_batch[2]), # x0 → x1_2 (branch 2)
|
| 132 |
+
]
|
| 133 |
+
elif self.metric_clusters == 2 and self.branches == 2:
|
| 134 |
+
branch_sample_pairs = [
|
| 135 |
+
(metric_samples_batch[0], metric_samples_batch[1]), # x0 → x1_1 (branch 1)
|
| 136 |
+
(metric_samples_batch[0], metric_samples_batch[1]), # x0 → x1_2 (branch 2)
|
| 137 |
+
]
|
| 138 |
+
else:
|
| 139 |
+
branch_sample_pairs = [
|
| 140 |
+
(metric_samples_batch[0], metric_samples_batch[1]), # x0 → x1_1 (branch 1)
|
| 141 |
+
]
|
| 142 |
+
"""samples0, samples1, samples2 = (
|
| 143 |
+
metric_samples_batch[0],
|
| 144 |
+
metric_samples_batch[1],
|
| 145 |
+
metric_samples_batch[2]
|
| 146 |
+
)"""
|
| 147 |
+
|
| 148 |
+
assert len(x1s_list) == self.branches, "Mismatch between x1s_list and expected branches"
|
| 149 |
+
|
| 150 |
+
# compute sum of velocities for each branch
|
| 151 |
+
loss = 0
|
| 152 |
+
velocities = []
|
| 153 |
+
for branch_idx in range(self.branches):
|
| 154 |
+
|
| 155 |
+
ts, xts, uts = self._process_flow(x0s, x1s_list[branch_idx], branch_idx)
|
| 156 |
+
|
| 157 |
+
for i in range(len(ts)):
|
| 158 |
+
# calculate kinetic and potential energy of the predicted interpolant
|
| 159 |
+
if self.args.manifold:
|
| 160 |
+
start_samples, end_samples = branch_sample_pairs[branch_idx]
|
| 161 |
+
|
| 162 |
+
samples = torch.cat([start_samples, end_samples], dim=0)
|
| 163 |
+
#print("metric sample shape")
|
| 164 |
+
#print(samples.shape)
|
| 165 |
+
vel, _, _ = self.data_manifold_metric.calculate_velocity(
|
| 166 |
+
xts[i], uts[i], samples, i
|
| 167 |
+
)
|
| 168 |
+
else:
|
| 169 |
+
vel = torch.sqrt((uts[i]**2).sum(dim =-1) + self.state_cost(xts[i]))
|
| 170 |
+
#vel = (uts[i]**2).sum(dim =-1)
|
| 171 |
+
|
| 172 |
+
velocities.append(vel)
|
| 173 |
+
|
| 174 |
+
loss = torch.mean(torch.cat(velocities) ** 2)
|
| 175 |
+
|
| 176 |
+
self.log(
|
| 177 |
+
"BranchPathNet/mean_velocity_geopath",
|
| 178 |
+
loss,
|
| 179 |
+
on_step=False,
|
| 180 |
+
on_epoch=True,
|
| 181 |
+
prog_bar=True,
|
| 182 |
+
)
|
| 183 |
+
|
| 184 |
+
return loss
|
| 185 |
+
|
| 186 |
+
def _process_flow(self, x0s, x1s, branch_idx):
|
| 187 |
+
ts, xts, uts = [], [], []
|
| 188 |
+
t_start = self.timesteps[0]
|
| 189 |
+
i_start = 0
|
| 190 |
+
|
| 191 |
+
for i, (x0, x1) in enumerate(zip(x0s, x1s)):
|
| 192 |
+
x0, x1 = torch.squeeze(x0), torch.squeeze(x1)
|
| 193 |
+
if self.trainer.validating or self.computing_reference_loss:
|
| 194 |
+
repeat_tuple = (self.multiply_validation, 1) + (1,) * (
|
| 195 |
+
len(x0.shape) - 2
|
| 196 |
+
)
|
| 197 |
+
x0 = x0.repeat(repeat_tuple)
|
| 198 |
+
x1 = x1.repeat(repeat_tuple)
|
| 199 |
+
|
| 200 |
+
if self.ot_sampler is not None:
|
| 201 |
+
x0, x1 = self.ot_sampler.sample_plan(
|
| 202 |
+
x0,
|
| 203 |
+
x1,
|
| 204 |
+
replace=True,
|
| 205 |
+
)
|
| 206 |
+
if self.skipped_time_points and i + 1 >= self.skipped_time_points[0]:
|
| 207 |
+
t_start_next = self.timesteps[i + 2]
|
| 208 |
+
else:
|
| 209 |
+
t_start_next = self.timesteps[i + 1]
|
| 210 |
+
|
| 211 |
+
t = None
|
| 212 |
+
if self.trainer.validating or self.computing_reference_loss:
|
| 213 |
+
t = self.t_val[i]
|
| 214 |
+
|
| 215 |
+
t, xt, ut = self.flow_matcher.sample_location_and_conditional_flow(
|
| 216 |
+
x0, x1, t_start, t_start_next, branch_idx, training_geopath_net=True, t=t
|
| 217 |
+
)
|
| 218 |
+
ts.append(t)
|
| 219 |
+
xts.append(xt)
|
| 220 |
+
uts.append(ut)
|
| 221 |
+
t_start = t_start_next
|
| 222 |
+
|
| 223 |
+
return ts, xts, uts
|
| 224 |
+
|
| 225 |
+
def training_step(self, batch, batch_idx):
|
| 226 |
+
if self.args.data_type in ["scrna", "tahoe"]:
|
| 227 |
+
main_batch = batch[0]["train_samples"][0]
|
| 228 |
+
metric_batch = batch[0]["metric_samples"][0]
|
| 229 |
+
else:
|
| 230 |
+
main_batch = batch["train_samples"][0]
|
| 231 |
+
metric_batch = batch["metric_samples"][0]
|
| 232 |
+
|
| 233 |
+
tangential_velocity_loss = self._compute_loss(main_batch, metric_batch)
|
| 234 |
+
|
| 235 |
+
if self.first_loss:
|
| 236 |
+
tangential_velocity_loss = tangential_velocity_loss / self.first_loss
|
| 237 |
+
|
| 238 |
+
self.log(
|
| 239 |
+
"BranchPathNet/mean_geopath_geopath",
|
| 240 |
+
(self.flow_matcher.geopath_net_output.abs().mean()),
|
| 241 |
+
on_step=False,
|
| 242 |
+
on_epoch=True,
|
| 243 |
+
prog_bar=True,
|
| 244 |
+
)
|
        self.log(
            "BranchPathNet/train_loss_geopath",
            tangential_velocity_loss,
            on_step=True,
            on_epoch=True,
            prog_bar=True,
            logger=True,
        )

        return tangential_velocity_loss

    def validation_step(self, batch, batch_idx):
        if self.args.data_type in ["scrna", "tahoe"]:
            main_batch = batch[0]["val_samples"][0]
            metric_batch = batch[0]["metric_samples"][0]
        else:
            main_batch = batch["val_samples"][0]
            metric_batch = batch["metric_samples"][0]

        tangential_velocity_loss = self._compute_loss(main_batch, metric_batch)
        if self.first_loss:
            tangential_velocity_loss = tangential_velocity_loss / self.first_loss

        self.log(
            "BranchPathNet/val_loss_geopath",
            tangential_velocity_loss,
            on_step=False,
            on_epoch=True,
            prog_bar=True,
            logger=True,
        )
        return tangential_velocity_loss

    def test_step(self, batch, batch_idx):
        main_batch = batch["test_samples"][0]
        metric_batch = batch["metric_samples"][0]

        x0 = main_batch["x0"][0]  # [B, D]
        cloud_points = main_batch["dataset"][0]  # full dataset, [N, D]

        x0 = x0.to(self.device)
        cloud_points = cloud_points.to(self.device)

        t_vals = [0.25, 0.5, 0.75]
        t_labels = ["t=1/4", "t=1/2", "t=3/4"]

        colors = {
            "x0": "#4D176C",
            "t=1/4": "#5C3B9D",
            "t=1/2": "#6172B9",
            "t=3/4": "#AC4E51",
            "x1": "#771F4F",
        }

        # Unwhiten cloud points if needed
        if self.whiten:
            cloud_points = torch.tensor(
                self.trainer.datamodule.scaler.inverse_transform(cloud_points.cpu().numpy())
            )

        for i in range(self.branches):
            geopath = self.geopath_nets[i]
            x1_key = f"x1_{i + 1}"
            if x1_key not in main_batch:
                print(f"Skipping branch {i + 1}: no final distribution {x1_key}")
                continue

            x1 = main_batch[x1_key][0].to(self.device)
            print(x1.shape)
            print(x0.shape)
            interpolated_points = []
            with torch.no_grad():
                for t_scalar in t_vals:
                    t_tensor = torch.full((x0.shape[0], 1), t_scalar, device=self.device)  # [B, 1]
                    xt = geopath(x0, x1, t_tensor).cpu()  # [B, D]
                    if self.whiten:
                        xt = torch.tensor(
                            self.trainer.datamodule.scaler.inverse_transform(xt.numpy())
                        )
                    interpolated_points.append(xt)

            if self.whiten:
                x0_plot = torch.tensor(
                    self.trainer.datamodule.scaler.inverse_transform(x0.cpu().numpy())
                )
                x1_plot = torch.tensor(
                    self.trainer.datamodule.scaler.inverse_transform(x1.cpu().numpy())
                )
            else:
                x0_plot = x0.cpu()
                x1_plot = x1.cpu()

            # Plot
            fig = plt.figure(figsize=(6, 5))
            ax = fig.add_subplot(111, projection="3d", computed_zorder=False)
            ax.view_init(elev=30, azim=-115, roll=0)
            plot_lidar(ax, cloud_points)

            # Initial x₀
            ax.scatter(
                x0_plot[:, 0], x0_plot[:, 1], x0_plot[:, 2],
                s=15, alpha=1.0, color=colors["x0"], label="x₀", depthshade=True,
                edgecolors="white",
                linewidths=0.3,
            )

            # Interpolated points
            for xt, t_label in zip(interpolated_points, t_labels):
                ax.scatter(
                    xt[:, 0], xt[:, 1], xt[:, 2],
                    s=15, alpha=1.0, color=colors[t_label], label=t_label, depthshade=True,
                    edgecolors="white",
                    linewidths=0.3,
                )

            # Final x₁
            ax.scatter(
                x1_plot[:, 0], x1_plot[:, 1], x1_plot[:, 2],
                s=15, alpha=1.0, color=colors["x1"], label="x₁", depthshade=True,
                edgecolors="white",
                linewidths=0.3,
            )

            ax.legend()
            save_path = f"/raid/st512/branchsbm/figures/{self.args.data_type}/lidar_geopath_branch_{i+1}.png"
            plt.savefig(save_path, dpi=300)
            plt.close()

    def optimizer_step(self, *args, **kwargs):
        super().optimizer_step(*args, **kwargs)
        if isinstance(self.geopath_nets, EMA):
            self.geopath_nets.update_ema()

    def configure_optimizers(self):
        if self.optimizer_name == "adam":
            optimizer = torch.optim.Adam(
                itertools.chain(*[net.parameters() for net in self.geopath_nets]), lr=self.lr
            )
        elif self.optimizer_name == "adamw":
            # weight_decay (self.weight_decay) is not passed to this per-branch AdamW;
            # only the superseded single-net variant applied it.
            optimizer = torch.optim.AdamW(
                itertools.chain(*[net.parameters() for net in self.geopath_nets]), lr=self.lr
            )
        return optimizer
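For reference, configure_optimizers above trains all branch interpolants with a single optimizer by chaining each network's parameter iterator. A minimal standalone sketch of the same pattern, using toy linear layers as stand-ins for the geopath networks (illustration only, not part of the commit):

import itertools
import torch
import torch.nn as nn

# Two stand-in "branch" networks; the real module keeps one geopath net per branch.
nets = nn.ModuleList([nn.Linear(3, 3) for _ in range(2)])

# One optimizer over the union of all branch parameters, as in configure_optimizers.
opt = torch.optim.Adam(itertools.chain(*[n.parameters() for n in nets]), lr=1e-3)

loss = sum((net(torch.randn(4, 3)) ** 2).mean() for net in nets)
loss.backward()
opt.step()  # a single step updates every branch network
opt.zero_grad()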
branchsbm/branchsbm.py
ADDED
@@ -0,0 +1,109 @@
import sys
sys.path.append("./BranchSBM")
import torch
from torchcfm.conditional_flow_matching import ConditionalFlowMatcher, pad_t_like_x
import torch.nn as nn


class BranchSBM(ConditionalFlowMatcher):
    def __init__(
        self, geopath_nets: nn.ModuleList = None, alpha: float = 1.0, *args, **kwargs
    ):
        super().__init__(*args, **kwargs)
        self.alpha = alpha
        self.geopath_nets = geopath_nets
        if self.alpha != 0:
            assert (
                geopath_nets is not None
            ), "GeoPath model must be provided if alpha != 0"

        self.branches = len(geopath_nets)

    def gamma(self, t, t_min, t_max):
        # Bridge weight: zero at both endpoints, maximal (1/2) at the midpoint.
        return (
            1.0
            - ((t - t_min) / (t_max - t_min)) ** 2
            - ((t_max - t) / (t_max - t_min)) ** 2
        )

    def d_gamma(self, t, t_min, t_max):
        return 2 * (-2 * t + t_max + t_min) / (t_max - t_min) ** 2

    def compute_mu_t(self, x0, x1, t, t_min, t_max, branch_idx):
        assert branch_idx < self.branches, "Index out of bounds"

        with torch.enable_grad():
            t = pad_t_like_x(t, x0)
            if self.alpha == 0:
                return (t_max - t) / (t_max - t_min) * x0 + (t - t_min) / (
                    t_max - t_min
                ) * x1

            # compute value for specific branch
            self.geopath_net_output = self.geopath_nets[branch_idx](x0, x1, t)
            if self.geopath_nets[branch_idx].time_geopath:
                self.doutput_dt = torch.autograd.grad(
                    self.geopath_net_output,
                    t,
                    grad_outputs=torch.ones_like(self.geopath_net_output),
                    create_graph=False,
                    retain_graph=True,
                )[0]
            return (
                (t_max - t) / (t_max - t_min) * x0
                + (t - t_min) / (t_max - t_min) * x1
                + self.gamma(t, t_min, t_max) * self.geopath_net_output
            )

    def sample_xt(self, x0, x1, t, epsilon, t_min, t_max, branch_idx):
        assert branch_idx < self.branches, "Index out of bounds"
        mu_t = self.compute_mu_t(x0, x1, t, t_min, t_max, branch_idx)
        sigma_t = self.compute_sigma_t(t)
        sigma_t = pad_t_like_x(sigma_t, x0)
        return mu_t + sigma_t * epsilon

    def sample_location_and_conditional_flow(
        self,
        x0,
        x1,
        t_min,
        t_max,
        branch_idx,
        training_geopath_net=False,
        midpoint_only=False,
        t=None,
    ):
        self.training_geopath_net = training_geopath_net
        with torch.enable_grad():
            if t is None:
                t = torch.rand(x0.shape[0], requires_grad=True)
            t = t.type_as(x0)
            t = t * (t_max - t_min) + t_min
            if midpoint_only:
                t = (t_max + t_min) / 2 * torch.ones_like(t).type_as(x0)

            assert len(t) == x0.shape[0], "t has to have batch size dimension"

            eps = self.sample_noise_like(x0)

            # compute xt and ut for branch_idx
            xt = self.sample_xt(x0, x1, t, eps, t_min, t_max, branch_idx)
            ut = self.compute_conditional_flow(x0, x1, t, xt, t_min, t_max, branch_idx)

        return t, xt, ut

    def compute_conditional_flow(self, x0, x1, t, xt, t_min, t_max, branch_idx):
        # Relies on compute_mu_t having been called first for this branch, which
        # caches self.geopath_net_output (and self.doutput_dt when time_geopath).
        del xt
        t = pad_t_like_x(t, x0)
        if self.alpha == 0:
            return (x1 - x0) / (t_max - t_min)

        return (
            (x1 - x0) / (t_max - t_min)
            + self.d_gamma(t, t_min, t_max) * self.geopath_net_output
            + (
                self.gamma(t, t_min, t_max) * self.doutput_dt
                if self.geopath_nets[branch_idx].time_geopath
                else 0
            )
        )
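The interpolant above is an endpoint-pinned bridge: gamma vanishes at t_min and t_max, so compute_mu_t recovers x0 and x1 exactly at the endpoints no matter what the geopath network outputs. A quick standalone check of that boundary behavior (illustration only, not part of the commit):

import torch

def gamma(t, t_min, t_max):
    # Same expression as BranchSBM.gamma above.
    return (
        1.0
        - ((t - t_min) / (t_max - t_min)) ** 2
        - ((t_max - t) / (t_max - t_min)) ** 2
    )

t = torch.tensor([0.0, 0.5, 1.0])
print(gamma(t, 0.0, 1.0))  # tensor([0.0000, 0.5000, 0.0000]): zero at both ends, 1/2 at the midpoint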
branchsbm/ema.py
ADDED
@@ -0,0 +1,64 @@
import torch


class EMA(torch.nn.Module):
    def __init__(self, model: torch.nn.Module, decay: float = 0.999):
        super().__init__()
        self.model = model
        self.decay = decay
        if hasattr(self.model, "time_geopath"):
            self.time_geopath = self.model.time_geopath

        # Put this in a buffer so that it gets included in the state dict
        self.register_buffer("num_updates", torch.tensor(0))

        self.shadow_params = torch.nn.ParameterList(
            [
                torch.nn.Parameter(p.clone().detach(), requires_grad=False)
                for p in model.parameters()
                if p.requires_grad
            ]
        )
        self.backup_params = []

    def train(self, mode: bool):
        if self.training and not mode:
            # Switching from train mode to eval mode. Back up the model parameters
            # and overwrite them with the shadow (EMA) params.
            self.backup()
            self.copy_to_model()
        elif not self.training and mode:
            # Switching from eval to train mode. Restore the `backup_params`.
            self.restore_to_model()

        super().train(mode)

    def update_ema(self):
        self.num_updates += 1
        num_updates = self.num_updates.item()
        decay = min(self.decay, (1 + num_updates) / (10 + num_updates))
        with torch.no_grad():
            params = [p for p in self.model.parameters() if p.requires_grad]
            for shadow, param in zip(self.shadow_params, params):
                shadow.sub_((1 - decay) * (shadow - param))

    def forward(self, *args, **kwargs):
        return self.model(*args, **kwargs)

    def copy_to_model(self):
        # Copy the shadow (EMA) parameters to the model.
        params = [p for p in self.model.parameters() if p.requires_grad]
        for shadow, param in zip(self.shadow_params, params):
            param.data.copy_(shadow.data)

    def backup(self):
        # Back up the current model parameters.
        if len(self.backup_params) > 0:
            for p, b in zip(self.model.parameters(), self.backup_params):
                b.data.copy_(p.data)
        else:
            self.backup_params = [param.clone() for param in self.model.parameters()]

    def restore_to_model(self):
        # Restore the backed-up parameters to the model.
        for param, backup in zip(self.model.parameters(), self.backup_params):
            param.data.copy_(backup.data)
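A minimal usage sketch for this wrapper, assuming a toy linear model (illustration only, not part of the commit): update_ema is called after each optimizer step, and switching to eval() swaps the shadow weights into the model while train() restores the live weights.

import torch

model = torch.nn.Linear(2, 1)
ema = EMA(model, decay=0.999)
opt = torch.optim.SGD(model.parameters(), lr=1e-2)

for _ in range(10):
    loss = (ema(torch.randn(8, 2)) ** 2).mean()
    loss.backward()
    opt.step()
    opt.zero_grad()
    ema.update_ema()  # pull the shadow parameters toward the live ones

ema.eval()   # backs up live weights, loads the EMA (shadow) weights
ema.train()  # restores the live weights for further training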
configs/.DS_Store
ADDED
Binary file (6.15 kB)
configs/experiment/cell_single_branch.yaml
ADDED
@@ -0,0 +1,12 @@
data_type: "scrna"
data_name: "mouse"
dim: 2
whiten: false
t_exclude: []
velocity_metric: "land"
gammas: [0.125]
rho: 0.001
branchsbm: true
seeds: [42]
patience_geopath: 50
time_geopath: true
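These experiment files are flat key-value YAML. How they are merged into the run arguments lives in train/parsers.py (not shown in this section), so the following is only an assumed sketch of the typical overlay pattern, with load_experiment and defaults as hypothetical names:

import argparse
import yaml  # PyYAML

def load_experiment(path: str, defaults: argparse.Namespace) -> argparse.Namespace:
    # Overlay a flat YAML experiment file onto parsed default args (hypothetical helper).
    with open(path) as f:
        overrides = yaml.safe_load(f)
    merged = vars(defaults).copy()
    merged.update(overrides)
    return argparse.Namespace(**merged)

# e.g. args = load_experiment("configs/experiment/cell_single_branch.yaml", defaults)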
configs/experiment/clonidine_100D.yaml
ADDED
@@ -0,0 +1,22 @@
data_type: "tahoe"
data_name: "clonidine100D"
accelerator: "gpu"
hidden_dims_geopath: [1024, 1024, 1024]
hidden_dims_flow: [1024, 1024, 1024]
hidden_dims_growth: [1024, 1024, 1024]
dim: 100
t_exclude: []
time_geopath: true
whiten: false
velocity_metric: "rbf"
metric_patience: 25
patience: 25
n_centers: 300
kappa: 2
rho: -2.75
alpha_metric: 1
metric_epochs: 200
branchsbm: true
seeds: [42]
branches: 2
metric_clusters: 3
configs/experiment/clonidine_150D.yaml
ADDED
@@ -0,0 +1,22 @@
data_type: "tahoe"
data_name: "clonidine150D"
accelerator: "gpu"
hidden_dims_geopath: [1024, 1024, 1024]
hidden_dims_flow: [1024, 1024, 1024]
hidden_dims_growth: [1024, 1024, 1024]
dim: 150
t_exclude: []
time_geopath: true
whiten: false
velocity_metric: "rbf"
metric_patience: 25
patience: 25
n_centers: 300
kappa: 3
rho: -2.75
alpha_metric: 1
metric_epochs: 400
branchsbm: true
seeds: [42]
branches: 2
metric_clusters: 3
configs/experiment/clonidine_50D.yaml
ADDED
@@ -0,0 +1,22 @@
data_type: "tahoe"
data_name: "clonidine50D"
accelerator: "gpu"
hidden_dims_geopath: [1024, 1024, 1024]
hidden_dims_flow: [1024, 1024, 1024]
hidden_dims_growth: [1024, 1024, 1024]
dim: 50
t_exclude: []
time_geopath: true
whiten: false
velocity_metric: "rbf"
metric_patience: 25
patience: 25
n_centers: 150
kappa: 1.5
rho: -2.75
alpha_metric: 1
metric_epochs: 200
branchsbm: true
seeds: [42]
branches: 2
metric_clusters: 3
configs/experiment/clonidine_50Dsingle.yaml
ADDED
@@ -0,0 +1,22 @@
data_type: "tahoe"
data_name: "clonidine50Dsingle"
accelerator: "gpu"
hidden_dims_geopath: [1024, 1024, 1024]
hidden_dims_flow: [1024, 1024, 1024]
hidden_dims_growth: [1024, 1024, 1024]
dim: 50
t_exclude: []
time_geopath: true
whiten: false
velocity_metric: "rbf"
metric_patience: 25
patience: 25
n_centers: 150
kappa: 1.5
rho: -2.75
alpha_metric: 1
metric_epochs: 200
branchsbm: true
seeds: [42]
branches: 1
metric_clusters: 2
configs/experiment/lidar.yaml
ADDED
@@ -0,0 +1,14 @@
data_type: "lidar"
data_name: "lidar"
dim: 3
whiten: true
t_exclude: []
velocity_metric: "land"
gammas: [0.125]
rho: 0.001
branchsbm: true
seeds: [42]
patience_geopath: 50
time_geopath: true
branches: 2
metric_clusters: 3
configs/experiment/lidar_single.yaml
ADDED
@@ -0,0 +1,14 @@
data_type: "lidar"
data_name: "lidarsingle"
dim: 3
whiten: true
t_exclude: []
velocity_metric: "land"
gammas: [0.125]
rho: 0.001
branchsbm: true
seeds: [42]
patience_geopath: 50
time_geopath: true
branches: 1
metric_clusters: 2
configs/experiment/mouse.yaml
ADDED
@@ -0,0 +1,17 @@
data_type: "scrna"
data_name: "mouse"
hidden_dims_geopath: [1024, 1024, 1024]
hidden_dims_flow: [1024, 1024, 1024]
hidden_dims_growth: [1024, 1024, 1024]
dim: 2
whiten: false
t_exclude: []
velocity_metric: "land"
gammas: [0.125]
rho: 0.001
branchsbm: true
seeds: [42]
patience_geopath: 50
time_geopath: true
branches: 2
metric_clusters: 3
configs/experiment/trametinib.yaml
ADDED
@@ -0,0 +1,22 @@
data_type: "tahoe"
data_name: "trametinib"
accelerator: "gpu"
hidden_dims_geopath: [1024, 1024, 1024]
hidden_dims_flow: [1024, 1024, 1024]
hidden_dims_growth: [1024, 1024, 1024]
dim: 50
t_exclude: []
time_geopath: true
whiten: false
velocity_metric: "rbf"
metric_patience: 25
patience: 25
n_centers: 150
kappa: 1.5
rho: -2.75
alpha_metric: 1
metric_epochs: 200
branchsbm: true
seeds: [42]
branches: 3
metric_clusters: 4
configs/experiment/trametinib_single.yaml
ADDED
@@ -0,0 +1,22 @@
data_type: "tahoe"
data_name: "trametinibsingle"
accelerator: "gpu"
hidden_dims_geopath: [1024, 1024, 1024]
hidden_dims_flow: [1024, 1024, 1024]
hidden_dims_growth: [1024, 1024, 1024]
dim: 50
t_exclude: []
time_geopath: true
whiten: false
velocity_metric: "rbf"
metric_patience: 25
patience: 25
n_centers: 150
kappa: 1.5
rho: -2.75
alpha_metric: 1
metric_epochs: 200
branchsbm: true
seeds: [42]
branches: 1
metric_clusters: 2
dataloaders/.DS_Store
ADDED
Binary file (6.15 kB)
dataloaders/clonidine_data.py
ADDED
@@ -0,0 +1,269 @@
import torch
import sys
sys.argv = ['']
from sklearn.preprocessing import StandardScaler
import pytorch_lightning as pl
from torch.utils.data import DataLoader, TensorDataset
from lightning.pytorch.utilities.combined_loader import CombinedLoader
import pandas as pd
import numpy as np
from functools import partial
from scipy.spatial import cKDTree
from sklearn.cluster import KMeans

# from train.parsers_tahoe import parse_args
# args = parse_args()


class DrugResponseDataModule(pl.LightningDataModule):
    def __init__(self, args):
        super().__init__()
        self.save_hyperparameters()

        self.batch_size = args.batch_size
        self.max_dim = args.dim
        self.whiten = args.whiten
        self.split_ratios = args.split_ratios

        # Path to the combined data
        self.data_path = "/raid/st512/branchsbm/data/pca_and_leiden_labels.csv"
        self.num_timesteps = 2
        self.args = args
        self._prepare_data()

    def _prepare_data(self):
        df = pd.read_csv(self.data_path, comment='#')
        df = df.iloc[:, 1:]
        df = df.replace('', np.nan)
        pc_cols = df.columns[:50]
        for col in pc_cols:
            df[col] = pd.to_numeric(df[col], errors='coerce')
        leiden_dmso_col = 'leiden_DMSO_TF_0.0uM'
        leiden_clonidine_col = 'leiden_Clonidine (hydrochloride)_5.0uM'

        dmso_mask = df[leiden_dmso_col].notna()  # Has leiden value in DMSO column
        clonidine_mask = df[leiden_clonidine_col].notna()  # Has leiden value in Clonidine column

        dmso_data = df[dmso_mask].copy()
        clonidine_data = df[clonidine_mask].copy()

        top_clonidine_clusters = ['0.0', '4.0']

        x1_1_data = clonidine_data[clonidine_data[leiden_clonidine_col].astype(str) == top_clonidine_clusters[0]]
        x1_2_data = clonidine_data[clonidine_data[leiden_clonidine_col].astype(str) == top_clonidine_clusters[1]]
        x1_1_coords = x1_1_data[pc_cols].values.astype(float)
        x1_2_coords = x1_2_data[pc_cols].values.astype(float)

        target_size = min(len(x1_1_coords), len(x1_2_coords))

        # Sample endpoint clusters to target size
        np.random.seed(42)
        if len(x1_1_coords) > target_size:
            idx1 = np.random.choice(len(x1_1_coords), target_size, replace=False)
            x1_1_coords = x1_1_coords[idx1]

        if len(x1_2_coords) > target_size:
            idx2 = np.random.choice(len(x1_2_coords), target_size, replace=False)
            x1_2_coords = x1_2_coords[idx2]

        dmso_cluster_counts = dmso_data[leiden_dmso_col].value_counts()

        # DMSO
        largest_dmso_cluster = dmso_cluster_counts.index[0]
        dmso_cluster_data = dmso_data[dmso_data[leiden_dmso_col] == largest_dmso_cluster]

        dmso_coords = dmso_cluster_data[pc_cols].values

        # Random sampling from largest DMSO cluster to match target size
        np.random.seed(42)
        if len(dmso_coords) >= target_size:
            idx0 = np.random.choice(len(dmso_coords), target_size, replace=False)
            x0_coords = dmso_coords[idx0]
        else:
            # If largest cluster is smaller than target, use all of it and pad with other DMSO cells
            remaining_needed = target_size - len(dmso_coords)
            other_dmso_data = dmso_data[dmso_data[leiden_dmso_col] != largest_dmso_cluster]
            other_dmso_coords = other_dmso_data[pc_cols].values

            if len(other_dmso_coords) >= remaining_needed:
                idx_other = np.random.choice(len(other_dmso_coords), remaining_needed, replace=False)
                x0_coords = np.vstack([dmso_coords, other_dmso_coords[idx_other]])
            else:
                # Use all available DMSO cells and reduce target size
                all_dmso_coords = dmso_data[pc_cols].values
                target_size = min(target_size, len(all_dmso_coords))
                idx0 = np.random.choice(len(all_dmso_coords), target_size, replace=False)
                x0_coords = all_dmso_coords[idx0]

        # Also resample endpoint clusters to match the final target size
        if len(x1_1_coords) > target_size:
            idx1 = np.random.choice(len(x1_1_coords), target_size, replace=False)
            x1_1_coords = x1_1_coords[idx1]

        if len(x1_2_coords) > target_size:
            idx2 = np.random.choice(len(x1_2_coords), target_size, replace=False)
            x1_2_coords = x1_2_coords[idx2]

        self.n_samples = target_size

        x0 = torch.tensor(x0_coords, dtype=torch.float32)
        x1_1 = torch.tensor(x1_1_coords, dtype=torch.float32)
        x1_2 = torch.tensor(x1_2_coords, dtype=torch.float32)

        self.coords_t0 = x0
        self.coords_t1 = torch.cat([x1_1, x1_2], dim=0)
        self.time_labels = np.concatenate([
            np.zeros(len(self.coords_t0)),  # t=0
            np.ones(len(self.coords_t1)),   # t=1
        ])

        split_index = int(target_size * self.split_ratios[0])

        if target_size - split_index < self.batch_size:
            split_index = target_size - self.batch_size

        train_x0 = x0[:split_index]
        val_x0 = x0[split_index:]
        train_x1_1 = x1_1[:split_index]
        val_x1_1 = x1_1[split_index:]
        train_x1_2 = x1_2[:split_index]
        val_x1_2 = x1_2[split_index:]

        self.val_x0 = val_x0

        train_x0_weights = torch.full((train_x0.shape[0], 1), fill_value=1.0)
        train_x1_1_weights = torch.full((train_x1_1.shape[0], 1), fill_value=0.5)
        train_x1_2_weights = torch.full((train_x1_2.shape[0], 1), fill_value=0.5)

        val_x0_weights = torch.full((val_x0.shape[0], 1), fill_value=1.0)
        val_x1_1_weights = torch.full((val_x1_1.shape[0], 1), fill_value=0.5)
        val_x1_2_weights = torch.full((val_x1_2.shape[0], 1), fill_value=0.5)

        self.train_dataloaders = {
            "x0": DataLoader(TensorDataset(train_x0, train_x0_weights), batch_size=self.batch_size, shuffle=True, drop_last=True),
            "x1_1": DataLoader(TensorDataset(train_x1_1, train_x1_1_weights), batch_size=self.batch_size, shuffle=True, drop_last=True),
            "x1_2": DataLoader(TensorDataset(train_x1_2, train_x1_2_weights), batch_size=self.batch_size, shuffle=True, drop_last=True),
        }

        self.val_dataloaders = {
            "x0": DataLoader(TensorDataset(val_x0, val_x0_weights), batch_size=self.batch_size, shuffle=False, drop_last=True),
            "x1_1": DataLoader(TensorDataset(val_x1_1, val_x1_1_weights), batch_size=self.batch_size, shuffle=True, drop_last=True),
            "x1_2": DataLoader(TensorDataset(val_x1_2, val_x1_2_weights), batch_size=self.batch_size, shuffle=True, drop_last=True),
        }

        all_coords = df[pc_cols].dropna().values.astype(float)
        self.dataset = torch.tensor(all_coords, dtype=torch.float32)
        self.tree = cKDTree(all_coords)

        self.test_dataloaders = {
            "x0": DataLoader(TensorDataset(val_x0, val_x0_weights), batch_size=self.val_x0.shape[0], shuffle=False, drop_last=False),
            "dataset": DataLoader(TensorDataset(self.dataset), batch_size=self.dataset.shape[0], shuffle=False, drop_last=False),
        }

        # Metric samples
        km_all = KMeans(n_clusters=3, random_state=0).fit(self.dataset.numpy())
        cluster_labels = km_all.labels_

        cluster_0_mask = cluster_labels == 0
        cluster_1_mask = cluster_labels == 1
        cluster_2_mask = cluster_labels == 2

        samples = self.dataset.cpu().numpy()

        cluster_0_data = samples[cluster_0_mask]
        cluster_1_data = samples[cluster_1_mask]
        cluster_2_data = samples[cluster_2_mask]

        self.metric_samples_dataloaders = [
            DataLoader(
                torch.tensor(cluster_2_data, dtype=torch.float32),
                batch_size=cluster_2_data.shape[0],
                shuffle=False,
                drop_last=False,
            ),
            DataLoader(
                torch.tensor(cluster_0_data, dtype=torch.float32),
                batch_size=cluster_0_data.shape[0],
                shuffle=False,
                drop_last=False,
            ),
            DataLoader(
                torch.tensor(cluster_1_data, dtype=torch.float32),
                batch_size=cluster_1_data.shape[0],
                shuffle=False,
                drop_last=False,
            ),
        ]

    def train_dataloader(self):
        combined_loaders = {
            "train_samples": CombinedLoader(self.train_dataloaders, mode="min_size"),
            "metric_samples": CombinedLoader(
                self.metric_samples_dataloaders, mode="min_size"
            ),
        }
        return CombinedLoader(combined_loaders, mode="max_size_cycle")

    def val_dataloader(self):
        combined_loaders = {
            "val_samples": CombinedLoader(self.val_dataloaders, mode="min_size"),
            "metric_samples": CombinedLoader(
                self.metric_samples_dataloaders, mode="min_size"
            ),
        }

        return CombinedLoader(combined_loaders, mode="max_size_cycle")

    def test_dataloader(self):
        combined_loaders = {
            "test_samples": CombinedLoader(self.test_dataloaders, mode="min_size"),
            "metric_samples": CombinedLoader(
                self.metric_samples_dataloaders, mode="min_size"
            ),
        }

        return CombinedLoader(combined_loaders, mode="max_size_cycle")

    def get_manifold_proj(self, points):
        """Adapted for 2D cell data - uses local neighborhood averaging instead of plane fitting"""
        return partial(self.local_smoothing_op, tree=self.tree, dataset=self.dataset)

    @staticmethod
    def local_smoothing_op(x, tree, dataset, k=10, temp=1e-3):
        """
        Apply local smoothing based on k-nearest neighbors in the full dataset.
        This replaces the plane projection for 2D manifold regularization.
        """
        points_np = x.detach().cpu().numpy()
        _, idx = tree.query(points_np, k=k)
        nearest_pts = dataset[idx]  # Shape: (batch_size, k, 2)

        # Compute weighted average of neighbors
        dists = (x.unsqueeze(1) - nearest_pts).pow(2).sum(-1, keepdim=True)
        weights = torch.exp(-dists / temp)
        weights = weights / weights.sum(dim=1, keepdim=True)

        # Weighted average of neighbors
        smoothed = (weights * nearest_pts).sum(dim=1)

        # Blend original point with smoothed version
        alpha = 0.3  # How much smoothing to apply
        return (1 - alpha) * x + alpha * smoothed

    def get_timepoint_data(self):
        """Return data organized by timepoints for visualization"""
        return {
            't0': self.coords_t0,
            't1': self.coords_t1,
            'time_labels': self.time_labels,
        }


def get_datamodule():
    # Note: relies on a module-level `args` (see the commented-out parser import
    # above), so those lines must be re-enabled before this helper is used.
    datamodule = DrugResponseDataModule(args)
    datamodule.setup(stage="fit")
    return datamodule
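local_smoothing_op above is a soft k-nearest-neighbor projection: each query point is pulled 30% of the way toward a distance-weighted average of its k nearest dataset points. A tiny self-contained demo on a random 2D cloud (stand-in data, and a larger temp than the default so the exponential weights stay well-scaled; illustration only, not part of the commit):

import torch
from scipy.spatial import cKDTree

dataset = torch.randn(200, 2)  # toy cloud standing in for the PCA coordinates
tree = cKDTree(dataset.numpy())

x = torch.randn(5, 2)  # points to regularize toward the manifold
smoothed = DrugResponseDataModule.local_smoothing_op(x, tree, dataset, k=10, temp=0.1)
print(smoothed.shape)  # torch.Size([5, 2]); each row is 0.7*x + 0.3*(soft-kNN average)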
dataloaders/clonidine_single_branch.py
ADDED
@@ -0,0 +1,274 @@
import torch
import sys
sys.argv = ['']
from sklearn.preprocessing import StandardScaler
import pytorch_lightning as pl
from torch.utils.data import DataLoader, TensorDataset
from lightning.pytorch.utilities.combined_loader import CombinedLoader
import pandas as pd
import numpy as np
from functools import partial
from scipy.spatial import cKDTree
from sklearn.cluster import KMeans

# uncomment for plotting
# from train.parsers_tahoe import parse_args
# args = parse_args()


class ClonidineSingleBranchDataModule(pl.LightningDataModule):
    def __init__(self, args):
        super().__init__()
        self.save_hyperparameters()

        self.batch_size = args.batch_size
        self.max_dim = args.dim
        self.whiten = args.whiten
        self.split_ratios = args.split_ratios

        self.dim = args.dim
        print("dimension")
        print(self.dim)
        # Path to the combined data
        self.data_path = "./data/pca_and_leiden_labels.csv"
        self.num_timesteps = 2
        self.args = args
        self._prepare_data()

    def _prepare_data(self):
        df = pd.read_csv(self.data_path, comment='#')
        df = df.iloc[:, 1:]
        df = df.replace('', np.nan)
        pc_cols = df.columns[:self.dim]
        for col in pc_cols:
            df[col] = pd.to_numeric(df[col], errors='coerce')
        leiden_dmso_col = 'leiden_DMSO_TF_0.0uM'
        leiden_clonidine_col = 'leiden_Clonidine (hydrochloride)_5.0uM'

        dmso_mask = df[leiden_dmso_col].notna()  # Has leiden value in DMSO column
        clonidine_mask = df[leiden_clonidine_col].notna()  # Has leiden value in Clonidine column

        dmso_data = df[dmso_mask].copy()
        clonidine_data = df[clonidine_mask].copy()

        top_clonidine_clusters = ['0.0', '4.0']

        x1_1_data = clonidine_data[clonidine_data[leiden_clonidine_col].astype(str) == top_clonidine_clusters[0]]
        x1_2_data = clonidine_data[clonidine_data[leiden_clonidine_col].astype(str) == top_clonidine_clusters[1]]

        x1_1_coords = x1_1_data[pc_cols].values.astype(float)
        x1_2_coords = x1_2_data[pc_cols].values.astype(float)

        # Target size is the minimum across the endpoint clusters
        target_size = min(len(x1_1_coords), len(x1_2_coords))

        # Helper function to select points closest to centroid
        def select_closest_to_centroid(coords, target_size):
            if len(coords) <= target_size:
                return coords

            # Calculate centroid
            centroid = np.mean(coords, axis=0)

            # Calculate distances to centroid
            distances = np.linalg.norm(coords - centroid, axis=1)

            # Get indices of closest points
            closest_indices = np.argsort(distances)[:target_size]

            return coords[closest_indices]

        # Sample all endpoint clusters to target size using centroid-based selection
        x1_1_coords = select_closest_to_centroid(x1_1_coords, target_size)
        x1_2_coords = select_closest_to_centroid(x1_2_coords, target_size)

        dmso_cluster_counts = dmso_data[leiden_dmso_col].value_counts()

        # DMSO (unchanged)
        largest_dmso_cluster = dmso_cluster_counts.index[0]
        dmso_cluster_data = dmso_data[dmso_data[leiden_dmso_col] == largest_dmso_cluster]

        dmso_coords = dmso_cluster_data[pc_cols].values

        # For DMSO we also use centroid-based selection, for consistency
        if len(dmso_coords) >= target_size:
            x0_coords = select_closest_to_centroid(dmso_coords, target_size)
        else:
            # If largest cluster is smaller than target, use all of it and pad with other DMSO cells
            remaining_needed = target_size - len(dmso_coords)
            other_dmso_data = dmso_data[dmso_data[leiden_dmso_col] != largest_dmso_cluster]
            other_dmso_coords = other_dmso_data[pc_cols].values

            if len(other_dmso_coords) >= remaining_needed:
                # Select closest to centroid from other DMSO cells
                other_selected = select_closest_to_centroid(other_dmso_coords, remaining_needed)
                x0_coords = np.vstack([dmso_coords, other_selected])
            else:
                # Use all available DMSO cells and reduce target size
                all_dmso_coords = dmso_data[pc_cols].values
                target_size = min(target_size, len(all_dmso_coords))
                x0_coords = select_closest_to_centroid(all_dmso_coords, target_size)

                # Re-select endpoint clusters with the updated target size
                x1_1_coords = select_closest_to_centroid(x1_1_data[pc_cols].values.astype(float), target_size)
                x1_2_coords = select_closest_to_centroid(x1_2_data[pc_cols].values.astype(float), target_size)

        # No further resampling needed: the endpoint clusters are already at
        # target_size from the centroid-based selection.

        self.n_samples = target_size

        x0 = torch.tensor(x0_coords, dtype=torch.float32)
        x1_1 = torch.tensor(x1_1_coords, dtype=torch.float32)
        x1_2 = torch.tensor(x1_2_coords, dtype=torch.float32)
        x1 = torch.cat([x1_1, x1_2], dim=0)

        self.coords_t0 = x0
        self.coords_t1 = x1
        self.time_labels = np.concatenate([
            np.zeros(len(self.coords_t0)),  # t=0
            np.ones(len(self.coords_t1)),   # t=1
        ])

        split_index = int(target_size * self.split_ratios[0])

        if target_size - split_index < self.batch_size:
            split_index = target_size - self.batch_size
        print('total count is:', target_size)

        train_x0 = x0[:split_index]
        val_x0 = x0[split_index:]
        train_x1 = x1[:split_index]
        val_x1 = x1[split_index:]

        self.val_x0 = val_x0

        train_x0_weights = torch.full((train_x0.shape[0], 1), fill_value=1.0)
        train_x1_weights = torch.full((train_x1.shape[0], 1), fill_value=1.0)

        val_x0_weights = torch.full((val_x0.shape[0], 1), fill_value=1.0)
        val_x1_weights = torch.full((val_x1.shape[0], 1), fill_value=1.0)

        self.train_dataloaders = {
            "x0": DataLoader(TensorDataset(train_x0, train_x0_weights), batch_size=self.batch_size, shuffle=True, drop_last=True),
            "x1": DataLoader(TensorDataset(train_x1, train_x1_weights), batch_size=self.batch_size, shuffle=True, drop_last=True),
        }

        self.val_dataloaders = {
            "x0": DataLoader(TensorDataset(val_x0, val_x0_weights), batch_size=self.batch_size, shuffle=False, drop_last=True),
            "x1": DataLoader(TensorDataset(val_x1, val_x1_weights), batch_size=self.batch_size, shuffle=True, drop_last=True),
        }

        all_coords = df[pc_cols].dropna().values.astype(float)
        self.dataset = torch.tensor(all_coords, dtype=torch.float32)
        self.tree = cKDTree(all_coords)

        self.test_dataloaders = {
            "x0": DataLoader(TensorDataset(val_x0, val_x0_weights), batch_size=self.val_x0.shape[0], shuffle=False, drop_last=False),
            "dataset": DataLoader(TensorDataset(self.dataset), batch_size=self.dataset.shape[0], shuffle=False, drop_last=False),
        }

        # Metric samples: two k-means clusters for the single-branch setup
        km_all = KMeans(n_clusters=2, random_state=0).fit(self.dataset.numpy())

        cluster_labels = km_all.labels_

        cluster_0_mask = cluster_labels == 0
        cluster_1_mask = cluster_labels == 1

        samples = self.dataset.cpu().numpy()

        cluster_0_data = samples[cluster_0_mask]
        cluster_1_data = samples[cluster_1_mask]

        self.metric_samples_dataloaders = [
            DataLoader(
                torch.tensor(cluster_1_data, dtype=torch.float32),
                batch_size=cluster_1_data.shape[0],
                shuffle=False,
                drop_last=False,
            ),
            DataLoader(
                torch.tensor(cluster_0_data, dtype=torch.float32),
                batch_size=cluster_0_data.shape[0],
                shuffle=False,
                drop_last=False,
            ),
        ]

    def train_dataloader(self):
        combined_loaders = {
            "train_samples": CombinedLoader(self.train_dataloaders, mode="min_size"),
            "metric_samples": CombinedLoader(
                self.metric_samples_dataloaders, mode="min_size"
            ),
        }
        return CombinedLoader(combined_loaders, mode="max_size_cycle")

    def val_dataloader(self):
        combined_loaders = {
            "val_samples": CombinedLoader(self.val_dataloaders, mode="min_size"),
            "metric_samples": CombinedLoader(
                self.metric_samples_dataloaders, mode="min_size"
            ),
        }

        return CombinedLoader(combined_loaders, mode="max_size_cycle")

    def test_dataloader(self):
        combined_loaders = {
            "test_samples": CombinedLoader(self.test_dataloaders, mode="min_size"),
            "metric_samples": CombinedLoader(
                self.metric_samples_dataloaders, mode="min_size"
            ),
        }

        return CombinedLoader(combined_loaders, mode="max_size_cycle")

    def get_manifold_proj(self, points):
        """Adapted for 2D cell data - uses local neighborhood averaging instead of plane fitting"""
        return partial(self.local_smoothing_op, tree=self.tree, dataset=self.dataset)

    @staticmethod
    def local_smoothing_op(x, tree, dataset, k=10, temp=1e-3):
        """
        Apply local smoothing based on k-nearest neighbors in the full dataset.
        This replaces the plane projection for 2D manifold regularization.
        """
        points_np = x.detach().cpu().numpy()
        _, idx = tree.query(points_np, k=k)
        nearest_pts = dataset[idx]  # Shape: (batch_size, k, 2)

        # Compute weighted average of neighbors
        dists = (x.unsqueeze(1) - nearest_pts).pow(2).sum(-1, keepdim=True)
        weights = torch.exp(-dists / temp)
        weights = weights / weights.sum(dim=1, keepdim=True)

        # Weighted average of neighbors
        smoothed = (weights * nearest_pts).sum(dim=1)

        # Blend original point with smoothed version
        alpha = 0.3  # How much smoothing to apply
        return (1 - alpha) * x + alpha * smoothed

    def get_timepoint_data(self):
        """Return data organized by timepoints for visualization"""
        return {
            't0': self.coords_t0,
            't1': self.coords_t1,
            'time_labels': self.time_labels,
        }


def get_datamodule():
    # Note: relies on a module-level `args` (see the commented-out parser import
    # above), so those lines must be re-enabled before this helper is used.
    datamodule = ClonidineSingleBranchDataModule(args)
    datamodule.setup(stage="fit")
    return datamodule
dataloaders/clonidine_v2_data.py
ADDED
|
@@ -0,0 +1,287 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import sys
|
| 3 |
+
sys.argv = ['']
|
| 4 |
+
from sklearn.preprocessing import StandardScaler
|
| 5 |
+
import pytorch_lightning as pl
|
| 6 |
+
from torch.utils.data import DataLoader
|
| 7 |
+
from lightning.pytorch.utilities.combined_loader import CombinedLoader
|
| 8 |
+
import pandas as pd
|
| 9 |
+
import numpy as np
|
| 10 |
+
from functools import partial
|
| 11 |
+
from scipy.spatial import cKDTree
|
| 12 |
+
from sklearn.cluster import KMeans
|
| 13 |
+
from torch.utils.data import TensorDataset
|
| 14 |
+
|
| 15 |
+
from train.parsers_tahoe import parse_args
|
| 16 |
+
args = parse_args()
|
| 17 |
+
|
| 18 |
+
class ClonidineV2DataModule(pl.LightningDataModule):
|
| 19 |
+
def __init__(self, args):
|
| 20 |
+
super().__init__()
|
| 21 |
+
self.save_hyperparameters()
|
| 22 |
+
|
| 23 |
+
self.batch_size = args.batch_size
|
| 24 |
+
self.max_dim = args.dim
|
| 25 |
+
self.whiten = args.whiten
|
| 26 |
+
self.split_ratios = args.split_ratios
|
| 27 |
+
|
| 28 |
+
self.dim = args.dim
|
| 29 |
+
print("dimension")
|
| 30 |
+
print(self.dim)
|
| 31 |
+
# Path to your combined data
|
| 32 |
+
self.data_path = "./data/pca_and_leiden_labels.csv"
|
| 33 |
+
self.num_timesteps = 2
|
| 34 |
+
self.args = args
|
| 35 |
+
self._prepare_data()
|
| 36 |
+
|
| 37 |
+
def _prepare_data(self):
|
| 38 |
+
df = pd.read_csv(self.data_path, comment='#')
|
| 39 |
+
df = df.iloc[:, 1:]
|
| 40 |
+
df = df.replace('', np.nan)
|
| 41 |
+
pc_cols = df.columns[:150]
|
| 42 |
+
for col in pc_cols:
|
| 43 |
+
df[col] = pd.to_numeric(df[col], errors='coerce')
|
| 44 |
+
leiden_dmso_col = 'leiden_DMSO_TF_0.0uM'
|
| 45 |
+
leiden_clonidine_col = 'leiden_Clonidine (hydrochloride)_5.0uM'
|
| 46 |
+
|
| 47 |
+
dmso_mask = df[leiden_dmso_col].notna() # Has leiden value in DMSO column
|
| 48 |
+
clonidine_mask = df[leiden_clonidine_col].notna() # Has leiden value in Clonidine column
|
| 49 |
+
|
| 50 |
+
dmso_data = df[dmso_mask].copy()
|
| 51 |
+
clonidine_data = df[clonidine_mask].copy()
|
| 52 |
+
|
| 53 |
+
top_clonidine_clusters = ['0.0', '4.0']
|
| 54 |
+
|
| 55 |
+
x1_1_data = clonidine_data[clonidine_data[leiden_clonidine_col].astype(str) == top_clonidine_clusters[0]]
|
| 56 |
+
x1_2_data = clonidine_data[clonidine_data[leiden_clonidine_col].astype(str) == top_clonidine_clusters[1]]
|
| 57 |
+
|
| 58 |
+
x1_1_coords = x1_1_data[pc_cols].values
|
| 59 |
+
x1_2_coords = x1_2_data[pc_cols].values
|
| 60 |
+
|
| 61 |
+
x1_1_coords = x1_1_coords.astype(float)
|
| 62 |
+
x1_2_coords = x1_2_coords.astype(float)
|
| 63 |
+
|
| 64 |
+
# Target size is now the minimum across all three endpoint clusters
|
| 65 |
+
target_size = min(len(x1_1_coords), len(x1_2_coords),)
|
| 66 |
+
|
| 67 |
+
# Helper function to select points closest to centroid
|
| 68 |
+
def select_closest_to_centroid(coords, target_size):
|
| 69 |
+
if len(coords) <= target_size:
|
| 70 |
+
return coords
|
| 71 |
+
|
| 72 |
+
# Calculate centroid
|
| 73 |
+
centroid = np.mean(coords, axis=0)
|
| 74 |
+
|
| 75 |
+
# Calculate distances to centroid
|
| 76 |
+
distances = np.linalg.norm(coords - centroid, axis=1)
|
| 77 |
+
|
| 78 |
+
# Get indices of closest points
|
| 79 |
+
closest_indices = np.argsort(distances)[:target_size]
|
| 80 |
+
|
| 81 |
+
return coords[closest_indices]
|
| 82 |
+
|
| 83 |
+
# Sample all endpoint clusters to target size using centroid-based selection
|
| 84 |
+
x1_1_coords = select_closest_to_centroid(x1_1_coords, target_size)
|
| 85 |
+
x1_2_coords = select_closest_to_centroid(x1_2_coords, target_size)
|
| 86 |
+
|
| 87 |
+
dmso_cluster_counts = dmso_data[leiden_dmso_col].value_counts()
|
| 88 |
+
|
| 89 |
+
# DMSO (unchanged)
|
| 90 |
+
largest_dmso_cluster = dmso_cluster_counts.index[0]
|
| 91 |
+
dmso_cluster_data = dmso_data[dmso_data[leiden_dmso_col] == largest_dmso_cluster]
|
| 92 |
+
|
| 93 |
+
dmso_coords = dmso_cluster_data[pc_cols].values
|
| 94 |
+
|
| 95 |
+
# Random sampling from largest DMSO cluster to match target size
|
| 96 |
+
# For DMSO, we'll also use centroid-based selection for consistency
|
| 97 |
+
if len(dmso_coords) >= target_size:
|
| 98 |
+
x0_coords = select_closest_to_centroid(dmso_coords, target_size)
|
| 99 |
+
else:
|
| 100 |
+
# If largest cluster is smaller than target, use all of it and pad with other DMSO cells
|
| 101 |
+
remaining_needed = target_size - len(dmso_coords)
|
| 102 |
+
other_dmso_data = dmso_data[dmso_data[leiden_dmso_col] != largest_dmso_cluster]
|
| 103 |
+
other_dmso_coords = other_dmso_data[pc_cols].values
|
| 104 |
+
|
| 105 |
+
if len(other_dmso_coords) >= remaining_needed:
|
| 106 |
+
# Select closest to centroid from other DMSO cells
|
| 107 |
+
other_selected = select_closest_to_centroid(other_dmso_coords, remaining_needed)
|
| 108 |
+
x0_coords = np.vstack([dmso_coords, other_selected])
|
| 109 |
+
else:
|
| 110 |
+
# Use all available DMSO cells and reduce target size
|
| 111 |
+
all_dmso_coords = dmso_data[pc_cols].values
|
| 112 |
+
target_size = min(target_size, len(all_dmso_coords))
|
| 113 |
+
x0_coords = select_closest_to_centroid(all_dmso_coords, target_size)
|
| 114 |
+
|
| 115 |
+
# Re-select endpoint clusters with updated target size
|
| 116 |
+
x1_1_coords = select_closest_to_centroid(x1_1_data[pc_cols].values.astype(float), target_size)
|
| 117 |
+
x1_2_coords = select_closest_to_centroid(x1_2_data[pc_cols].values.astype(float), target_size)
|
| 118 |
+
|
| 119 |
+
# No need to resample since we already selected the right number
|
| 120 |
+
# The endpoint clusters are already at target_size from centroid-based selection
|
| 121 |
+
|
| 122 |
+
self.n_samples = target_size
|
| 123 |
+
|
| 124 |
+
x0 = torch.tensor(x0_coords, dtype=torch.float32)
|
| 125 |
+
x1_1 = torch.tensor(x1_1_coords, dtype=torch.float32)
|
| 126 |
+
x1_2 = torch.tensor(x1_2_coords, dtype=torch.float32)
|
| 127 |
+
|
| 128 |
+
self.coords_t0 = x0
|
| 129 |
+
self.coords_t1_1 = x1_1
|
| 130 |
+
self.coords_t1_2 = x1_2
|
| 131 |
+
self.time_labels = np.concatenate([
|
| 132 |
+
np.zeros(len(self.coords_t0)), # t=0
|
| 133 |
+
np.ones(len(self.coords_t1_1)), # t=1
|
| 134 |
+
np.ones(len(self.coords_t1_2)),
|
| 135 |
+
])
|
| 136 |
+
|
| 137 |
+
split_index = int(target_size * self.split_ratios[0])
|
| 138 |
+
|
| 139 |
+
if target_size - split_index < self.batch_size:
|
| 140 |
+
split_index = target_size - self.batch_size
|
| 141 |
+
        print('total count is:', target_size)

        train_x0 = x0[:split_index]
        val_x0 = x0[split_index:]
        train_x1_1 = x1_1[:split_index]
        val_x1_1 = x1_1[split_index:]
        train_x1_2 = x1_2[:split_index]
        val_x1_2 = x1_2[split_index:]

        self.val_x0 = val_x0

        train_x0_weights = torch.full((train_x0.shape[0], 1), fill_value=1.0)
        train_x1_1_weights = torch.full((train_x1_1.shape[0], 1), fill_value=0.5)
        train_x1_2_weights = torch.full((train_x1_2.shape[0], 1), fill_value=0.5)

        val_x0_weights = torch.full((val_x0.shape[0], 1), fill_value=1.0)
        val_x1_1_weights = torch.full((val_x1_1.shape[0], 1), fill_value=0.5)
        val_x1_2_weights = torch.full((val_x1_2.shape[0], 1), fill_value=0.5)

        # Train dataloaders for the source (x0) and the two branch endpoints (x1_1, x1_2)
        self.train_dataloaders = {
            "x0": DataLoader(TensorDataset(train_x0, train_x0_weights), batch_size=self.batch_size, shuffle=True, drop_last=True),
            "x1_1": DataLoader(TensorDataset(train_x1_1, train_x1_1_weights), batch_size=self.batch_size, shuffle=True, drop_last=True),
            "x1_2": DataLoader(TensorDataset(train_x1_2, train_x1_2_weights), batch_size=self.batch_size, shuffle=True, drop_last=True),
        }

        self.val_dataloaders = {
            "x0": DataLoader(TensorDataset(val_x0, val_x0_weights), batch_size=self.batch_size, shuffle=False, drop_last=True),
            "x1_1": DataLoader(TensorDataset(val_x1_1, val_x1_1_weights), batch_size=self.batch_size, shuffle=True, drop_last=True),
            "x1_2": DataLoader(TensorDataset(val_x1_2, val_x1_2_weights), batch_size=self.batch_size, shuffle=True, drop_last=True),
        }

        all_coords = df[pc_cols].dropna().values.astype(float)
        self.dataset = torch.tensor(all_coords, dtype=torch.float32)
        self.tree = cKDTree(all_coords)

        self.test_dataloaders = {
            "x0": DataLoader(TensorDataset(val_x0, val_x0_weights), batch_size=self.val_x0.shape[0], shuffle=False, drop_last=False),
            "dataset": DataLoader(TensorDataset(self.dataset), batch_size=self.dataset.shape[0], shuffle=False, drop_last=False),
        }

        km_all = KMeans(n_clusters=3, random_state=0).fit(self.dataset.numpy())

        cluster_labels = km_all.labels_

        cluster_0_mask = cluster_labels == 0
        cluster_1_mask = cluster_labels == 1
        cluster_2_mask = cluster_labels == 2

        samples = self.dataset.cpu().numpy()

        cluster_0_data = samples[cluster_0_mask]
        cluster_1_data = samples[cluster_1_mask]
        cluster_2_data = samples[cluster_2_mask]

        self.metric_samples_dataloaders = [
            DataLoader(
                torch.tensor(cluster_2_data, dtype=torch.float32),
                batch_size=cluster_2_data.shape[0],
                shuffle=False,
                drop_last=False,
            ),
            DataLoader(
                torch.tensor(cluster_0_data, dtype=torch.float32),
                batch_size=cluster_0_data.shape[0],
                shuffle=False,
                drop_last=False,
            ),
            DataLoader(
                torch.tensor(cluster_1_data, dtype=torch.float32),
                batch_size=cluster_1_data.shape[0],
                shuffle=False,
                drop_last=False,
            ),
        ]

    def train_dataloader(self):
        combined_loaders = {
            "train_samples": CombinedLoader(self.train_dataloaders, mode="min_size"),
            "metric_samples": CombinedLoader(
                self.metric_samples_dataloaders, mode="min_size"
            ),
        }
        return CombinedLoader(combined_loaders, mode="max_size_cycle")

    def val_dataloader(self):
        combined_loaders = {
            "val_samples": CombinedLoader(self.val_dataloaders, mode="min_size"),
            "metric_samples": CombinedLoader(
                self.metric_samples_dataloaders, mode="min_size"
            ),
        }
        return CombinedLoader(combined_loaders, mode="max_size_cycle")

    def test_dataloader(self):
        combined_loaders = {
            "test_samples": CombinedLoader(self.test_dataloaders, mode="min_size"),
            "metric_samples": CombinedLoader(
                self.metric_samples_dataloaders, mode="min_size"
            ),
        }
        return CombinedLoader(combined_loaders, mode="max_size_cycle")

    def get_manifold_proj(self, points):
        """Adapted for 2D cell data: uses local neighborhood averaging instead of plane fitting."""
        return partial(self.local_smoothing_op, tree=self.tree, dataset=self.dataset)

    @staticmethod
    def local_smoothing_op(x, tree, dataset, k=10, temp=1e-3):
        """
        Apply local smoothing based on the k nearest neighbors in the full dataset.
        This replaces the plane projection for 2D manifold regularization.
        """
        points_np = x.detach().cpu().numpy()
        _, idx = tree.query(points_np, k=k)
        nearest_pts = dataset[idx]  # Shape: (batch_size, k, 2)

        # Compute a normalized, distance-based weight for each neighbor
        dists = (x.unsqueeze(1) - nearest_pts).pow(2).sum(-1, keepdim=True)
        weights = torch.exp(-dists / temp)
        weights = weights / weights.sum(dim=1, keepdim=True)

        # Weighted average of neighbors
        smoothed = (weights * nearest_pts).sum(dim=1)

        # Blend the original point with its smoothed version
        alpha = 0.3  # How much smoothing to apply
        return (1 - alpha) * x + alpha * smoothed

    def get_timepoint_data(self):
        """Return data organized by timepoint for visualization."""
        return {
            't0': self.coords_t0,
            't1_1': self.coords_t1_1,
            't1_2': self.coords_t1_2,
            'time_labels': self.time_labels
        }

def get_datamodule():
    datamodule = ClonidineV2DataModule(args)
    datamodule.setup(stage="fit")
    return datamodule
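A minimal sketch of how one nested batch from the CombinedLoader hierarchy above can be unpacked (the training_step below is hypothetical; only the key layout and the (tensor, weight) pairs come from the datamodule):

    def training_step(self, batch, batch_idx):
        x0, x0_w = batch["train_samples"]["x0"]      # source points + marginal weights (1.0)
        x1_1, w1 = batch["train_samples"]["x1_1"]    # branch-1 endpoints (weight 0.5)
        x1_2, w2 = batch["train_samples"]["x1_2"]    # branch-2 endpoints (weight 0.5)
        cluster_batches = batch["metric_samples"]    # list of full per-cluster tensors
        loss = ...  # the actual branch losses are computed downstream of these marginals
        return loss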
dataloaders/lidar_data.py
ADDED
@@ -0,0 +1,532 @@
import torch
import sys
sys.argv = ['']
from sklearn.preprocessing import StandardScaler
import pytorch_lightning as pl
from torch.utils.data import DataLoader
from pytorch_lightning.utilities.combined_loader import CombinedLoader
import laspy
import numpy as np
from scipy.spatial import cKDTree
import math
from functools import partial
from torch.utils.data import TensorDataset

#from train.parsers import parse_args
#args = parse_args()

class GaussianMM:
    def __init__(self, mu, var):
        super().__init__()
        self.centers = torch.tensor(mu)
        self.logstd = torch.tensor(var).log() / 2.0
        self.K = self.centers.shape[0]

    def logprob(self, x):
        logprobs = self.normal_logprob(
            x.unsqueeze(1), self.centers.unsqueeze(0), self.logstd
        )
        logprobs = torch.sum(logprobs, dim=2)
        return torch.logsumexp(logprobs, dim=1) - math.log(self.K)

    def normal_logprob(self, z, mean, log_std):
        mean = mean + torch.tensor(0.0)
        log_std = log_std + torch.tensor(0.0)
        c = torch.tensor([math.log(2 * math.pi)]).to(z)
        inv_sigma = torch.exp(-log_std)
        tmp = (z - mean) * inv_sigma
        return -0.5 * (tmp * tmp + 2 * log_std + c)

    def __call__(self, n_samples):
        idx = torch.randint(self.K, (n_samples,)).to(self.centers.device)
        mean = self.centers[idx]
        return torch.randn(*mean.shape).to(mean) * torch.exp(self.logstd) + mean

class BranchedLidarDataModule(pl.LightningDataModule):
    def __init__(self, args):
        super().__init__()
        self.save_hyperparameters()

        self.data_path = args.data_path
        self.batch_size = args.batch_size
        self.max_dim = args.dim
        self.whiten = args.whiten
        self.p0_mu = [
            [-4.5, -4.0, 0.5],
            [-4.2, -3.5, 0.5],
            [-4.0, -3.0, 0.5],
            [-3.75, -2.5, 0.5],
        ]
        self.p0_var = 0.02

        self.p1_1_mu = [
            [-2.5, -0.25, 0.5],
            [-2.25, 0.675, 0.5],
            [-2, 1.5, 0.5],
        ]
        self.p1_2_mu = [
            [2, -2, 0.5],
            [2.6, -1.25, 0.5],
            [3.2, -0.5, 0.5],
        ]

        self.p1_var = 0.03
        self.k = 20
        self.n_samples = 5000
        self.num_timesteps = 2
        self.split_ratios = args.split_ratios
        self._prepare_data()

    def assign_region(self):
        all_centers = {
            0: torch.tensor(self.p0_mu),    # Region 0: p0
            1: torch.tensor(self.p1_1_mu),  # Region 1: p1_1
            2: torch.tensor(self.p1_2_mu),  # Region 2: p1_2
        }

        dataset = self.dataset.to(torch.float32)
        N = dataset.shape[0]
        assignments = torch.zeros(N, dtype=torch.long)

        # For each point, compute the min distance to each region's centers
        # and assign the point to the closest region.
        for i in range(N):
            point = dataset[i]
            min_dist = float("inf")
            best_region = 0
            for region, centers in all_centers.items():
                dists = ((centers - point) ** 2).sum(dim=1)
                region_min = dists.min()
                if region_min < min_dist:
                    min_dist = region_min
                    best_region = region
            assignments[i] = best_region
        return assignments

    def _prepare_data(self):
        las = laspy.read(self.data_path)
        # Extract only "ground" points.
        self.mask = las.classification == 2
        # Original preprocessing: undo the LAS integer encoding, then rescale
        # the cloud into a [-5, 5] x [-5, 5] x [0, 2] box.
        x_offset, x_scale = las.header.offsets[0], las.header.scales[0]
        y_offset, y_scale = las.header.offsets[1], las.header.scales[1]
        z_offset, z_scale = las.header.offsets[2], las.header.scales[2]
        dataset = np.vstack(
            (
                las.X[self.mask] * x_scale + x_offset,
                las.Y[self.mask] * y_scale + y_offset,
                las.Z[self.mask] * z_scale + z_offset,
            )
        ).transpose()
        mi = dataset.min(axis=0, keepdims=True)
        ma = dataset.max(axis=0, keepdims=True)
        dataset = (dataset - mi) / (ma - mi) * [10.0, 10.0, 2.0] + [-5.0, -5.0, 0.0]

        self.dataset = torch.tensor(dataset, dtype=torch.float32)
        self.tree = cKDTree(dataset)

        x0_gaussian = GaussianMM(self.p0_mu, self.p0_var)(self.n_samples)
        x1_1_gaussian = GaussianMM(self.p1_1_mu, self.p1_var)(self.n_samples)
        x1_2_gaussian = GaussianMM(self.p1_2_mu, self.p1_var)(self.n_samples)

        x0 = self.get_tangent_proj(x0_gaussian)(x0_gaussian)
        x1_1 = self.get_tangent_proj(x1_1_gaussian)(x1_1_gaussian)
        x1_2 = self.get_tangent_proj(x1_2_gaussian)(x1_2_gaussian)

        split_index = int(self.n_samples * self.split_ratios[0])
        # Adjust split_index (before splitting) to ensure a minimum number of
        # validation samples.
        if self.n_samples - split_index < self.batch_size:
            split_index = self.n_samples - self.batch_size

        self.scaler = StandardScaler()
        if self.whiten:
            self.dataset = torch.tensor(
                self.scaler.fit_transform(dataset), dtype=torch.float32
            )
            x0 = torch.tensor(self.scaler.transform(x0), dtype=torch.float32)
            x1_1 = torch.tensor(self.scaler.transform(x1_1), dtype=torch.float32)
            x1_2 = torch.tensor(self.scaler.transform(x1_2), dtype=torch.float32)

        train_x0 = x0[:split_index]
        val_x0 = x0[split_index:]

        # branches
        train_x1_1 = x1_1[:split_index]
        print("train_x1_1")
        print(train_x1_1.shape)
        val_x1_1 = x1_1[split_index:]
        train_x1_2 = x1_2[:split_index]
        val_x1_2 = x1_2[split_index:]

        self.val_x0 = val_x0

        self.train_dataloaders = {
            "x0": DataLoader(train_x0, batch_size=self.batch_size, shuffle=True, drop_last=True),
            "x1_1": DataLoader(train_x1_1, batch_size=self.batch_size, shuffle=True, drop_last=True),
            "x1_2": DataLoader(train_x1_2, batch_size=self.batch_size, shuffle=True, drop_last=True),
        }
        self.val_dataloaders = {
            "x0": DataLoader(val_x0, batch_size=self.batch_size, shuffle=False, drop_last=True),
            "x1_1": DataLoader(val_x1_1, batch_size=self.batch_size, shuffle=True, drop_last=True),
            "x1_2": DataLoader(val_x1_2, batch_size=self.batch_size, shuffle=True, drop_last=True),
        }
        # to edit?
        self.test_dataloaders = [
            DataLoader(
                self.val_x0,
                batch_size=self.val_x0.shape[0],
                shuffle=False,
                drop_last=False,
            ),
            DataLoader(
                self.dataset,
                batch_size=self.dataset.shape[0],
                shuffle=False,
                drop_last=False,
            ),
        ]

        points = self.dataset.cpu().numpy()
        x, y = points[:, 0], points[:, 1]
        # Diagonal-based coordinates (rotated 45°)
        u = (x + y) / np.sqrt(2)  # along x = y
        # Start region (A) using u
        u_thresh = np.percentile(u, 30)  # tweak this threshold to control size
        mask_A = u <= u_thresh

        # Among the rest, split by the x = y diagonal
        remaining = ~mask_A
        mask_B = remaining & (x < y)   # left of diagonal
        mask_C = remaining & (x >= y)  # right of diagonal

        # Assign one metric dataloader per region
        self.metric_samples_dataloaders = [
            DataLoader(torch.tensor(points[mask_A], dtype=torch.float32), batch_size=points[mask_A].shape[0], shuffle=False),
            DataLoader(torch.tensor(points[mask_B], dtype=torch.float32), batch_size=points[mask_B].shape[0], shuffle=False),
            DataLoader(torch.tensor(points[mask_C], dtype=torch.float32), batch_size=points[mask_C].shape[0], shuffle=False),
        ]

    def train_dataloader(self):
        combined_loaders = {
            "train_samples": CombinedLoader(self.train_dataloaders, mode="min_size"),
            "metric_samples": CombinedLoader(
                self.metric_samples_dataloaders, mode="min_size"
            ),
        }
        return CombinedLoader(combined_loaders, mode="max_size_cycle")

    def val_dataloader(self):
        combined_loaders = {
            "val_samples": CombinedLoader(self.val_dataloaders, mode="min_size"),
            "metric_samples": CombinedLoader(
                self.metric_samples_dataloaders, mode="min_size"
            ),
        }
        return CombinedLoader(combined_loaders, mode="max_size_cycle")

    def test_dataloader(self):
        return CombinedLoader(self.test_dataloaders)

    def get_tangent_proj(self, points):
        w = self.get_tangent_plane(points)
        return partial(BranchedLidarDataModule.projection_op, w=w)

    def get_tangent_plane(self, points, temp=1e-3):
        points_np = points.detach().cpu().numpy()
        _, idx = self.tree.query(points_np, k=self.k)
        nearest_pts = self.dataset[idx]
        nearest_pts = torch.tensor(nearest_pts).to(points)

        dists = (points.unsqueeze(1) - nearest_pts).pow(2).sum(-1, keepdim=True)
        weights = torch.exp(-dists / temp)

        # Fits the plane with the least (weighted) vertical distance.
        w = BranchedLidarDataModule.fit_plane(nearest_pts, weights)
        return w

    @staticmethod
    def fit_plane(points, weights=None):
        """Expects points to be of shape (..., 3).
        Returns [a, b, c] such that the plane is defined as
            ax + by + c = z
        """
        D = torch.cat([points[..., :2], torch.ones_like(points[..., 2:3])], dim=-1)
        z = points[..., 2]
        if weights is None:
            # Unweighted least squares.
            Dtrans = D.transpose(-1, -2)
        else:
            # Weighted least squares.
            DW = D * weights
            Dtrans = DW.transpose(-1, -2)
        w = torch.linalg.solve(
            torch.matmul(Dtrans, D), torch.matmul(Dtrans, z.unsqueeze(-1))
        ).squeeze(-1)
        return w

    @staticmethod
    def projection_op(x, w):
        """Projects points onto the plane defined by w."""
        # Normal vector to the tangent plane.
        n = torch.cat([w[..., :2], -torch.ones_like(w[..., 2:3])], dim=1)

        pn = torch.sum(x * n, dim=-1, keepdim=True)
        nn = torch.sum(n * n, dim=-1, keepdim=True)

        # Offset.
        d = w[..., 2:3]

        # Projection of x onto n.
        projn_x = ((pn + d) / nn) * n

        # Remove the component in the normal direction.
        return x - projn_x

class WeightedBranchedLidarDataModule(pl.LightningDataModule):
    def __init__(self, args):
        super().__init__()
        self.save_hyperparameters()

        self.data_path = args.data_path
        self.batch_size = args.batch_size
        self.max_dim = args.dim
        self.whiten = args.whiten
        self.p0_mu = [
            [-4.5, -4.0, 0.5],
            [-4.2, -3.5, 0.5],
            [-4.0, -3.0, 0.5],
            [-3.75, -2.5, 0.5],
        ]
        self.p0_var = 0.02
        # multiple p1, one per branch
        self.p1_1_mu = [
            [-2.5, -0.25, 0.5],
            [-2.25, 0.675, 0.5],
            [-2, 1.5, 0.5],
        ]
        self.p1_2_mu = [
            [2, -2, 0.5],
            [2.6, -1.25, 0.5],
            [3.2, -0.5, 0.5],
        ]

        self.p1_var = 0.03
        self.k = 20
        self.n_samples = 5000
        self.num_timesteps = 2
        self.split_ratios = args.split_ratios
        self.metric_clusters = 3
        self.args = args
        self._prepare_data()

    def _prepare_data(self):
        las = laspy.read(self.data_path)
        # Extract only "ground" points.
        self.mask = las.classification == 2
        # Original preprocessing (same as BranchedLidarDataModule).
        x_offset, x_scale = las.header.offsets[0], las.header.scales[0]
        y_offset, y_scale = las.header.offsets[1], las.header.scales[1]
        z_offset, z_scale = las.header.offsets[2], las.header.scales[2]
        dataset = np.vstack(
            (
                las.X[self.mask] * x_scale + x_offset,
                las.Y[self.mask] * y_scale + y_offset,
                las.Z[self.mask] * z_scale + z_offset,
            )
        ).transpose()
        mi = dataset.min(axis=0, keepdims=True)
        ma = dataset.max(axis=0, keepdims=True)
        dataset = (dataset - mi) / (ma - mi) * [10.0, 10.0, 2.0] + [-5.0, -5.0, 0.0]

        self.dataset = torch.tensor(dataset, dtype=torch.float32)
        self.tree = cKDTree(dataset)

        x0_gaussian = GaussianMM(self.p0_mu, self.p0_var)(self.n_samples)
        x1_1_gaussian = GaussianMM(self.p1_1_mu, self.p1_var)(self.n_samples)
        x1_2_gaussian = GaussianMM(self.p1_2_mu, self.p1_var)(self.n_samples)

        x0 = self.get_tangent_proj(x0_gaussian)(x0_gaussian)
        x1_1 = self.get_tangent_proj(x1_1_gaussian)(x1_1_gaussian)
        x1_2 = self.get_tangent_proj(x1_2_gaussian)(x1_2_gaussian)

        split_index = int(self.n_samples * self.split_ratios[0])
        # Adjust split_index (before splitting) to ensure a minimum number of
        # validation samples.
        if self.n_samples - split_index < self.batch_size:
            split_index = self.n_samples - self.batch_size

        self.scaler = StandardScaler()
        if self.whiten:
            self.dataset = torch.tensor(
                self.scaler.fit_transform(dataset), dtype=torch.float32
            )
            x0 = torch.tensor(self.scaler.transform(x0), dtype=torch.float32)
            x1_1 = torch.tensor(self.scaler.transform(x1_1), dtype=torch.float32)
            x1_2 = torch.tensor(self.scaler.transform(x1_2), dtype=torch.float32)

        self.coords_t0 = x0
        self.coords_t1_1 = x1_1
        self.coords_t1_2 = x1_2
        self.time_labels = np.concatenate([
            np.zeros(len(self.coords_t0)),   # t=0
            np.ones(len(self.coords_t1_1)),  # t=1
            np.ones(len(self.coords_t1_2)),  # t=1
        ])

        train_x0 = x0[:split_index]
        val_x0 = x0[split_index:]

        # branches
        train_x1_1 = x1_1[:split_index]
        val_x1_1 = x1_1[split_index:]
        train_x1_2 = x1_2[:split_index]
        val_x1_2 = x1_2[split_index:]

        self.val_x0 = val_x0

        train_x0_weights = torch.full((train_x0.shape[0], 1), fill_value=1.0)
        train_x1_1_weights = torch.full((train_x1_1.shape[0], 1), fill_value=0.5)
        train_x1_2_weights = torch.full((train_x1_2.shape[0], 1), fill_value=0.5)

        val_x0_weights = torch.full((val_x0.shape[0], 1), fill_value=1.0)
        val_x1_1_weights = torch.full((val_x1_1.shape[0], 1), fill_value=0.5)
        val_x1_2_weights = torch.full((val_x1_2.shape[0], 1), fill_value=0.5)

        self.train_dataloaders = {
            "x0": DataLoader(TensorDataset(train_x0, train_x0_weights), batch_size=self.batch_size, shuffle=True, drop_last=True),
            "x1_1": DataLoader(TensorDataset(train_x1_1, train_x1_1_weights), batch_size=self.batch_size, shuffle=True, drop_last=True),
            "x1_2": DataLoader(TensorDataset(train_x1_2, train_x1_2_weights), batch_size=self.batch_size, shuffle=True, drop_last=True),
        }

        self.val_dataloaders = {
            "x0": DataLoader(TensorDataset(val_x0, val_x0_weights), batch_size=self.batch_size, shuffle=False, drop_last=True),
            "x1_1": DataLoader(TensorDataset(val_x1_1, val_x1_1_weights), batch_size=self.batch_size, shuffle=True, drop_last=True),
            "x1_2": DataLoader(TensorDataset(val_x1_2, val_x1_2_weights), batch_size=self.batch_size, shuffle=True, drop_last=True),
        }

        # to edit?
        self.test_dataloaders = {
            "x0": DataLoader(TensorDataset(val_x0, val_x0_weights), batch_size=self.val_x0.shape[0], shuffle=False, drop_last=False),
            "x1_1": DataLoader(TensorDataset(val_x1_1, val_x1_1_weights), batch_size=self.val_x0.shape[0], shuffle=True, drop_last=True),
            "x1_2": DataLoader(TensorDataset(val_x1_2, val_x1_2_weights), batch_size=self.val_x0.shape[0], shuffle=True, drop_last=True),
            "dataset": DataLoader(TensorDataset(self.dataset), batch_size=self.dataset.shape[0], shuffle=False, drop_last=False),
        }

        points = self.dataset.cpu().numpy()
        x, y = points[:, 0], points[:, 1]
        # Diagonal-based coordinates (rotated 45°)
        u = (x + y) / np.sqrt(2)  # along x = y
        # Start region (A) using u
        u_thresh = np.percentile(u, 30)  # tweak this threshold to control size
        mask_A = u <= u_thresh

        # Among the rest, split by the x = y diagonal
        remaining = ~mask_A
        mask_B = remaining & (x < y)   # left of diagonal
        mask_C = remaining & (x >= y)  # right of diagonal

        # Assign one metric dataloader per region
        self.metric_samples_dataloaders = [
            DataLoader(torch.tensor(points[mask_A], dtype=torch.float32), batch_size=points[mask_A].shape[0], shuffle=False),
            DataLoader(torch.tensor(points[mask_B], dtype=torch.float32), batch_size=points[mask_B].shape[0], shuffle=False),
            DataLoader(torch.tensor(points[mask_C], dtype=torch.float32), batch_size=points[mask_C].shape[0], shuffle=False),
        ]

    def train_dataloader(self):
        combined_loaders = {
            "train_samples": CombinedLoader(self.train_dataloaders, mode="min_size"),
            "metric_samples": CombinedLoader(
                self.metric_samples_dataloaders, mode="min_size"
            ),
        }
        return CombinedLoader(combined_loaders, mode="max_size_cycle")

    def val_dataloader(self):
        combined_loaders = {
            "val_samples": CombinedLoader(self.val_dataloaders, mode="min_size"),
            "metric_samples": CombinedLoader(
                self.metric_samples_dataloaders, mode="min_size"
            ),
        }
        return CombinedLoader(combined_loaders, mode="max_size_cycle")

    def test_dataloader(self):
        combined_loaders = {
            "test_samples": CombinedLoader(self.test_dataloaders, mode="min_size"),
            "metric_samples": CombinedLoader(
                self.metric_samples_dataloaders, mode="min_size"
            ),
        }
        return CombinedLoader(combined_loaders, mode="max_size_cycle")

    def get_tangent_proj(self, points):
        w = self.get_tangent_plane(points)
        return partial(BranchedLidarDataModule.projection_op, w=w)

    def get_tangent_plane(self, points, temp=1e-3):
        points_np = points.detach().cpu().numpy()
        _, idx = self.tree.query(points_np, k=self.k)
        nearest_pts = self.dataset[idx]
        nearest_pts = torch.tensor(nearest_pts).to(points)

        dists = (points.unsqueeze(1) - nearest_pts).pow(2).sum(-1, keepdim=True)
        weights = torch.exp(-dists / temp)

        # Fits the plane with the least (weighted) vertical distance.
        w = BranchedLidarDataModule.fit_plane(nearest_pts, weights)
        return w

    @staticmethod
    def fit_plane(points, weights=None):
        """Expects points to be of shape (..., 3).
        Returns [a, b, c] such that the plane is defined as
            ax + by + c = z
        """
        D = torch.cat([points[..., :2], torch.ones_like(points[..., 2:3])], dim=-1)
        z = points[..., 2]
        if weights is None:
            # Unweighted least squares.
            Dtrans = D.transpose(-1, -2)
        else:
            # Weighted least squares.
            DW = D * weights
            Dtrans = DW.transpose(-1, -2)
        w = torch.linalg.solve(
            torch.matmul(Dtrans, D), torch.matmul(Dtrans, z.unsqueeze(-1))
        ).squeeze(-1)
        return w

    @staticmethod
    def projection_op(x, w):
        """Projects points onto the plane defined by w."""
        # Normal vector to the tangent plane.
        n = torch.cat([w[..., :2], -torch.ones_like(w[..., 2:3])], dim=1)

        pn = torch.sum(x * n, dim=-1, keepdim=True)
        nn = torch.sum(n * n, dim=-1, keepdim=True)

        # Offset.
        d = w[..., 2:3]

        # Projection of x onto n.
        projn_x = ((pn + d) / nn) * n

        # Remove the component in the normal direction.
        return x - projn_x

    def get_timepoint_data(self):
        """Return data organized by timepoint for visualization."""
        return {
            't0': self.coords_t0,
            't1_1': self.coords_t1_1,
            't1_2': self.coords_t1_2,
            'time_labels': self.time_labels
        }

def get_datamodule():
    # Relies on a module-level `args`; the parse_args import above is commented
    # out, so uncomment it (or inject args another way) before calling this helper.
    datamodule = WeightedBranchedLidarDataModule(args)
    datamodule.setup(stage="fit")
    return datamodule
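A quick sanity check of the fit_plane / projection_op pair above (a sketch; it assumes this module is importable, i.e. laspy and pytorch_lightning are installed, and the plane coefficients below are invented for the demo):

    import torch
    from dataloaders.lidar_data import BranchedLidarDataModule

    torch.manual_seed(0)
    pts = torch.randn(64, 3)
    # Place the points near the plane z = 0.5x - 0.25y + 1, plus small noise.
    pts[:, 2] = 0.5 * pts[:, 0] - 0.25 * pts[:, 1] + 1.0 + 0.01 * torch.randn(64)

    w = BranchedLidarDataModule.fit_plane(pts.unsqueeze(0))  # (1, 3): [a, b, c]
    proj = BranchedLidarDataModule.projection_op(pts, w)     # snap points onto the plane

    # Every projected point should satisfy a*x + b*y + c = z up to float error.
    residual = w[0, 0] * proj[:, 0] + w[0, 1] * proj[:, 1] + w[0, 2] - proj[:, 2]
    print(residual.abs().max())  # ~0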
dataloaders/lidar_data_single.py
ADDED
@@ -0,0 +1,282 @@
import torch
import sys
sys.argv = ['']
from sklearn.preprocessing import StandardScaler
import pytorch_lightning as pl
from torch.utils.data import DataLoader
from pytorch_lightning.utilities.combined_loader import CombinedLoader
import laspy
import numpy as np
from scipy.spatial import cKDTree
import math
from functools import partial
from torch.utils.data import TensorDataset

from train.parsers import parse_args
args = parse_args()

class GaussianMM:
    def __init__(self, mu, var):
        super().__init__()
        self.centers = torch.tensor(mu)
        self.logstd = torch.tensor(var).log() / 2.0
        self.K = self.centers.shape[0]

    def logprob(self, x):
        logprobs = self.normal_logprob(
            x.unsqueeze(1), self.centers.unsqueeze(0), self.logstd
        )
        logprobs = torch.sum(logprobs, dim=2)
        return torch.logsumexp(logprobs, dim=1) - math.log(self.K)

    def normal_logprob(self, z, mean, log_std):
        mean = mean + torch.tensor(0.0)
        log_std = log_std + torch.tensor(0.0)
        c = torch.tensor([math.log(2 * math.pi)]).to(z)
        inv_sigma = torch.exp(-log_std)
        tmp = (z - mean) * inv_sigma
        return -0.5 * (tmp * tmp + 2 * log_std + c)

    def __call__(self, n_samples):
        idx = torch.randint(self.K, (n_samples,)).to(self.centers.device)
        mean = self.centers[idx]
        return torch.randn(*mean.shape).to(mean) * torch.exp(self.logstd) + mean

class LidarSingleDataModule(pl.LightningDataModule):
    def __init__(self, args):
        super().__init__()
        self.save_hyperparameters()

        self.data_path = args.data_path
        self.batch_size = args.batch_size
        self.max_dim = args.dim
        self.whiten = args.whiten
        self.p0_mu = [
            [-4.5, -4.0, 0.5],
            [-4.2, -3.5, 0.5],
            [-4.0, -3.0, 0.5],
            [-3.75, -2.5, 0.5],
        ]
        self.p0_var = 0.02
        # multiple p1, one per branch
        self.p1_1_mu = [
            [-2.5, -0.25, 0.5],
            [-2.25, 0.675, 0.5],
            [-2, 1.5, 0.5],
        ]
        self.p1_2_mu = [
            [2, -2, 0.5],
            [2.6, -1.25, 0.5],
            [3.2, -0.5, 0.5],
        ]

        self.p1_var = 0.03
        self.k = 20
        self.n_samples = 5000
        self.num_timesteps = 2
        self.split_ratios = args.split_ratios
        self.metric_clusters = 3
        self.args = args
        self._prepare_data()

    def _prepare_data(self):
        las = laspy.read(self.data_path)
        # Extract only "ground" points.
        self.mask = las.classification == 2
        # Original preprocessing: undo the LAS integer encoding, then rescale
        # the cloud into a [-5, 5] x [-5, 5] x [0, 2] box.
        x_offset, x_scale = las.header.offsets[0], las.header.scales[0]
        y_offset, y_scale = las.header.offsets[1], las.header.scales[1]
        z_offset, z_scale = las.header.offsets[2], las.header.scales[2]
        dataset = np.vstack(
            (
                las.X[self.mask] * x_scale + x_offset,
                las.Y[self.mask] * y_scale + y_offset,
                las.Z[self.mask] * z_scale + z_offset,
            )
        ).transpose()
        mi = dataset.min(axis=0, keepdims=True)
        ma = dataset.max(axis=0, keepdims=True)
        dataset = (dataset - mi) / (ma - mi) * [10.0, 10.0, 2.0] + [-5.0, -5.0, 0.0]

        self.dataset = torch.tensor(dataset, dtype=torch.float32)
        self.tree = cKDTree(dataset)

        x0_gaussian = GaussianMM(self.p0_mu, self.p0_var)(self.n_samples)
        x1_1_gaussian = GaussianMM(self.p1_1_mu, self.p1_var)(self.n_samples)
        x1_2_gaussian = GaussianMM(self.p1_2_mu, self.p1_var)(self.n_samples)

        x0 = self.get_tangent_proj(x0_gaussian)(x0_gaussian)
        x1_1 = self.get_tangent_proj(x1_1_gaussian)(x1_1_gaussian)
        x1_2 = self.get_tangent_proj(x1_2_gaussian)(x1_2_gaussian)

        split_index = int(self.n_samples * self.split_ratios[0])
        # Adjust split_index (before splitting) to ensure a minimum number of
        # validation samples.
        if self.n_samples - split_index < self.batch_size:
            split_index = self.n_samples - self.batch_size

        self.scaler = StandardScaler()
        if self.whiten:
            self.dataset = torch.tensor(
                self.scaler.fit_transform(dataset), dtype=torch.float32
            )
            x0 = torch.tensor(self.scaler.transform(x0), dtype=torch.float32)
            x1_1 = torch.tensor(self.scaler.transform(x1_1), dtype=torch.float32)
            x1_2 = torch.tensor(self.scaler.transform(x1_2), dtype=torch.float32)
        # Merge both endpoint clouds into one target marginal.
        # NOTE: x1 lists all branch-1 points before all branch-2 points, so the
        # positional split below draws train and val from different branches;
        # shuffle first if a mixed split is wanted.
        x1 = torch.cat([x1_1, x1_2], dim=0)

        self.coords_t0 = x0
        self.coords_t1 = x1
        self.time_labels = np.concatenate([
            np.zeros(len(self.coords_t0)),  # t=0
            np.ones(len(self.coords_t1)),   # t=1
        ])

        train_x0 = x0[:split_index]
        val_x0 = x0[split_index:]

        # branches
        train_x1 = x1[:split_index]
        val_x1 = x1[split_index:]

        self.val_x0 = val_x0

        train_x0_weights = torch.full((train_x0.shape[0], 1), fill_value=1.0)
        train_x1_weights = torch.full((train_x1.shape[0], 1), fill_value=1.0)

        val_x0_weights = torch.full((val_x0.shape[0], 1), fill_value=1.0)
        val_x1_weights = torch.full((val_x1.shape[0], 1), fill_value=1.0)

        self.train_dataloaders = {
            "x0": DataLoader(TensorDataset(train_x0, train_x0_weights), batch_size=self.batch_size, shuffle=True, drop_last=True),
            "x1": DataLoader(TensorDataset(train_x1, train_x1_weights), batch_size=self.batch_size, shuffle=True, drop_last=True),
        }

        self.val_dataloaders = {
            "x0": DataLoader(TensorDataset(val_x0, val_x0_weights), batch_size=self.batch_size, shuffle=True, drop_last=True),
            "x1": DataLoader(TensorDataset(val_x1, val_x1_weights), batch_size=self.batch_size, shuffle=True, drop_last=True),
        }

        # to edit?
        self.test_dataloaders = {
            "x0": DataLoader(TensorDataset(val_x0, val_x0_weights), batch_size=self.val_x0.shape[0], shuffle=True, drop_last=False),
            "x1": DataLoader(TensorDataset(val_x1, val_x1_weights), batch_size=self.val_x0.shape[0], shuffle=True, drop_last=True),
            "dataset": DataLoader(TensorDataset(self.dataset), batch_size=self.dataset.shape[0], shuffle=True, drop_last=False),
        }

        points = self.dataset.cpu().numpy()
        x, y = points[:, 0], points[:, 1]
        # Diagonal-based coordinates (rotated 45°)
        u = (x + y) / np.sqrt(2)  # along x = y
        # Start region (A) using u
        u_thresh = np.percentile(u, 30)  # tweak this threshold to control size
        mask_A = u <= u_thresh

        # Among the rest, split by the x = y diagonal
        remaining = ~mask_A
        mask_B = remaining & (x < y)   # left of diagonal
        mask_C = remaining & (x >= y)  # right of diagonal

        # Assign metric dataloaders (only the start region and its complement
        # are used in the single-branch setting)
        self.metric_samples_dataloaders = [
            DataLoader(torch.tensor(points[mask_A], dtype=torch.float32), batch_size=points[mask_A].shape[0], shuffle=False),
            DataLoader(torch.tensor(points[remaining], dtype=torch.float32), batch_size=points[remaining].shape[0], shuffle=False),
        ]

    def train_dataloader(self):
        combined_loaders = {
            "train_samples": CombinedLoader(self.train_dataloaders, mode="min_size"),
            "metric_samples": CombinedLoader(
                self.metric_samples_dataloaders, mode="min_size"
            ),
        }
        return CombinedLoader(combined_loaders, mode="max_size_cycle")

    def val_dataloader(self):
        combined_loaders = {
            "val_samples": CombinedLoader(self.val_dataloaders, mode="min_size"),
            "metric_samples": CombinedLoader(
                self.metric_samples_dataloaders, mode="min_size"
            ),
        }
        return CombinedLoader(combined_loaders, mode="max_size_cycle")

    def test_dataloader(self):
        combined_loaders = {
            "test_samples": CombinedLoader(self.test_dataloaders, mode="min_size"),
            "metric_samples": CombinedLoader(
                self.metric_samples_dataloaders, mode="min_size"
            ),
        }
        return CombinedLoader(combined_loaders, mode="max_size_cycle")

    def get_tangent_proj(self, points):
        w = self.get_tangent_plane(points)
        return partial(LidarSingleDataModule.projection_op, w=w)

    def get_tangent_plane(self, points, temp=1e-3):
        points_np = points.detach().cpu().numpy()
        _, idx = self.tree.query(points_np, k=self.k)
        nearest_pts = self.dataset[idx]
        nearest_pts = torch.tensor(nearest_pts).to(points)

        dists = (points.unsqueeze(1) - nearest_pts).pow(2).sum(-1, keepdim=True)
        weights = torch.exp(-dists / temp)

        # Fits the plane with the least (weighted) vertical distance.
        w = LidarSingleDataModule.fit_plane(nearest_pts, weights)
        return w

    @staticmethod
    def fit_plane(points, weights=None):
        """Expects points to be of shape (..., 3).
        Returns [a, b, c] such that the plane is defined as
            ax + by + c = z
        """
        D = torch.cat([points[..., :2], torch.ones_like(points[..., 2:3])], dim=-1)
        z = points[..., 2]
        if weights is None:
            # Unweighted least squares.
            Dtrans = D.transpose(-1, -2)
        else:
            # Weighted least squares.
            DW = D * weights
            Dtrans = DW.transpose(-1, -2)
        w = torch.linalg.solve(
            torch.matmul(Dtrans, D), torch.matmul(Dtrans, z.unsqueeze(-1))
        ).squeeze(-1)
        return w

    @staticmethod
    def projection_op(x, w):
        """Projects points onto the plane defined by w."""
        # Normal vector to the tangent plane.
        n = torch.cat([w[..., :2], -torch.ones_like(w[..., 2:3])], dim=1)

        pn = torch.sum(x * n, dim=-1, keepdim=True)
        nn = torch.sum(n * n, dim=-1, keepdim=True)

        # Offset.
        d = w[..., 2:3]

        # Projection of x onto n.
        projn_x = ((pn + d) / nn) * n

        # Remove the component in the normal direction.
        return x - projn_x

    def get_timepoint_data(self):
        """Return data organized by timepoint for visualization."""
        return {
            't0': self.coords_t0,
            't1': self.coords_t1,
            'time_labels': self.time_labels
        }

def get_datamodule():
    datamodule = LidarSingleDataModule(args)
    datamodule.setup(stage="fit")
    return datamodule
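A short usage sketch for the GaussianMM endpoint sampler defined above (it assumes the class is in scope; the two mixture means are invented for the demo):

    import torch

    mm = GaussianMM(mu=[[0.0, 0.0, 0.5], [1.0, 1.0, 0.5]], var=0.02)
    samples = mm(256)           # (256, 3): each row drawn from a uniformly chosen component
    logp = mm.logprob(samples)  # (256,): log-density under the equal-weight mixture
    print(samples.shape, logp.mean().item())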
dataloaders/mouse_data.py
ADDED
@@ -0,0 +1,438 @@
| 1 |
+
import torch
|
| 2 |
+
import sys
|
| 3 |
+
sys.argv = ['']
|
| 4 |
+
from sklearn.preprocessing import StandardScaler
|
| 5 |
+
import pytorch_lightning as pl
|
| 6 |
+
from torch.utils.data import DataLoader
|
| 7 |
+
from lightning.pytorch.utilities.combined_loader import CombinedLoader
|
| 8 |
+
import numpy as np
|
| 9 |
+
from scipy.spatial import cKDTree
|
| 10 |
+
import math
|
| 11 |
+
from functools import partial
|
| 12 |
+
from sklearn.cluster import KMeans, DBSCAN
|
| 13 |
+
import matplotlib.pyplot as plt
|
| 14 |
+
import pandas as pd
|
| 15 |
+
from torch.utils.data import TensorDataset
|
| 16 |
+
|
| 17 |
+
from train.parsers_sc import parse_args
|
| 18 |
+
args = parse_args()
|
| 19 |
+
|
| 20 |
+
class WeightedBranchedCellDataModule(pl.LightningDataModule):
|
| 21 |
+
def __init__(self, args):
|
| 22 |
+
super().__init__()
|
| 23 |
+
self.save_hyperparameters()
|
| 24 |
+
|
| 25 |
+
self.data_path = "./data/mouse_hematopoiesis.csv"
|
| 26 |
+
self.batch_size = args.batch_size
|
| 27 |
+
self.max_dim = args.dim
|
| 28 |
+
self.whiten = args.whiten
|
| 29 |
+
self.k = 20
|
| 30 |
+
self.n_samples = 1429
|
| 31 |
+
self.num_timesteps = 3 # t=0, t=1, t=2
|
| 32 |
+
self.split_ratios = args.split_ratios
|
| 33 |
+
self.metric_clusters = args.metric_clusters
|
| 34 |
+
self.args = args
|
| 35 |
+
self._prepare_data()
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def _prepare_data(self):
|
| 39 |
+
print("Preparing cell data in BranchedCellDataModule")
|
| 40 |
+
|
| 41 |
+
df = pd.read_csv(self.data_path)
|
| 42 |
+
|
| 43 |
+
# Build dictionary of coordinates by time
|
| 44 |
+
coords_by_t = {
|
| 45 |
+
t: df[df["samples"] == t][["x1","x2"]].values
|
| 46 |
+
for t in sorted(df["samples"].unique())
|
| 47 |
+
}
|
| 48 |
+
n0 = coords_by_t[0].shape[0] # Number of T=0 points
|
| 49 |
+
self.n_samples = n0 # Update n_samples to match actual data if changes
|
| 50 |
+
|
| 51 |
+
# Cluster the t=2 cells into two branches
|
| 52 |
+
km = KMeans(n_clusters=2, random_state=42).fit(coords_by_t[2])
|
| 53 |
+
df2 = df[df["samples"] == 2].copy()
|
| 54 |
+
df2["branch"] = km.labels_
|
| 55 |
+
|
| 56 |
+
cluster_counts = df2["branch"].value_counts().sort_index()
|
| 57 |
+
print(cluster_counts)
|
| 58 |
+
|
| 59 |
+
# Sample n0 points from each branch
|
| 60 |
+
endpoints = {}
|
| 61 |
+
for b in (0, 1):
|
| 62 |
+
endpoints[b] = (
|
| 63 |
+
df2[df2["branch"] == b]
|
| 64 |
+
.sample(n=n0, random_state=42)[["x1","x2"]]
|
| 65 |
+
.values
|
| 66 |
+
)
|
| 67 |
+
|
| 68 |
+
x0 = torch.tensor(coords_by_t[0], dtype=torch.float32) # T=0 coordinates index
|
| 69 |
+
x_inter = torch.tensor(coords_by_t[1], dtype=torch.float32)
|
| 70 |
+
x1_1 = torch.tensor(endpoints[0], dtype=torch.float32) # Branch index
|
| 71 |
+
x1_2 = torch.tensor(endpoints[1], dtype=torch.float32) # Branch index
|
| 72 |
+
|
| 73 |
+
self.coords_t0 = x0
|
| 74 |
+
self.coords_t1 = x_inter
|
| 75 |
+
self.coords_t2_1 = x1_1
|
| 76 |
+
self.coords_t2_2 = x1_2
|
| 77 |
+
self.time_labels = np.concatenate([
|
| 78 |
+
np.zeros(len(self.coords_t0)), # t=0
|
| 79 |
+
np.ones(len(self.coords_t1)), # t=1
|
| 80 |
+
np.ones(len(self.coords_t2_1)) * 2, # t=1
|
| 81 |
+
np.ones(len(self.coords_t2_2)) * 2,
|
| 82 |
+
])
|
| 83 |
+
|
| 84 |
+
split_index = int(n0 * self.split_ratios[0])
|
| 85 |
+
|
| 86 |
+
if n0 - split_index < self.batch_size:
|
| 87 |
+
split_index = n0 - self.batch_size
|
| 88 |
+
|
| 89 |
+
train_x0 = x0[:split_index]
|
| 90 |
+
val_x0 = x0[split_index:]
|
| 91 |
+
train_x1_1 = x1_1[:split_index]
|
| 92 |
+
val_x1_1 = x1_1[split_index:]
|
| 93 |
+
train_x1_2 = x1_2[:split_index]
|
| 94 |
+
val_x1_2 = x1_2[split_index:]
|
| 95 |
+
|
| 96 |
+
self.val_x0 = val_x0
|
| 97 |
+
|
| 98 |
+
train_x0_weights = torch.full((train_x0.shape[0], 1), fill_value=1.0)
|
| 99 |
+
train_x1_1_weights = torch.full((train_x1_1.shape[0], 1), fill_value=0.5)
|
| 100 |
+
train_x1_2_weights = torch.full((train_x1_2.shape[0], 1), fill_value=0.5)
|
| 101 |
+
|
| 102 |
+
val_x0_weights = torch.full((val_x0.shape[0], 1), fill_value=1.0)
|
| 103 |
+
val_x1_1_weights = torch.full((val_x1_1.shape[0], 1), fill_value=0.5)
|
| 104 |
+
val_x1_2_weights = torch.full((val_x1_2.shape[0], 1), fill_value=0.5)
|
| 105 |
+
|
| 106 |
+
if self.n_samples - split_index < self.batch_size:
|
| 107 |
+
split_index = self.n_samples - self.batch_size
|
| 108 |
+
|
| 109 |
+
self.train_dataloaders = {
|
| 110 |
+
"x0": DataLoader(TensorDataset(train_x0, train_x0_weights), batch_size=self.batch_size, shuffle=True, drop_last=True),
|
| 111 |
+
"x1_1": DataLoader(TensorDataset(train_x1_1, train_x1_1_weights), batch_size=self.batch_size, shuffle=True, drop_last=True),
|
| 112 |
+
"x1_2": DataLoader(TensorDataset(train_x1_2, train_x1_2_weights), batch_size=self.batch_size, shuffle=True, drop_last=True),
|
| 113 |
+
}
|
| 114 |
+
|
| 115 |
+
self.val_dataloaders = {
|
| 116 |
+
"x0": DataLoader(TensorDataset(val_x0, val_x0_weights), batch_size=self.batch_size, shuffle=False, drop_last=True),
|
| 117 |
+
"x1_1": DataLoader(TensorDataset(val_x1_1, val_x1_1_weights), batch_size=self.batch_size, shuffle=True, drop_last=True),
|
| 118 |
+
"x1_2": DataLoader(TensorDataset(val_x1_2, val_x1_2_weights), batch_size=self.batch_size, shuffle=True, drop_last=True),
|
| 119 |
+
}
|
| 120 |
+
|
| 121 |
+
all_data = np.vstack([coords_by_t[t] for t in sorted(coords_by_t.keys())])
|
| 122 |
+
self.dataset = torch.tensor(all_data, dtype=torch.float32)
|
| 123 |
+
self.tree = cKDTree(all_data)
|
| 124 |
+
|
| 125 |
+
# if whitening is enabled, need to apply this to the full dataset
|
| 126 |
+
#if self.whiten:
|
| 127 |
+
#self.scaler = StandardScaler()
|
| 128 |
+
#self.dataset = torch.tensor(
|
| 129 |
+
#self.scaler.fit_transform(all_data), dtype=torch.float32
|
| 130 |
+
#)
|
| 131 |
+
|
| 132 |
+
self.test_dataloaders = {
|
| 133 |
+
"x0": DataLoader(TensorDataset(val_x0, val_x0_weights), batch_size=self.val_x0.shape[0], shuffle=False, drop_last=False),
|
| 134 |
+
"dataset": DataLoader(TensorDataset(self.dataset), batch_size=self.dataset.shape[0], shuffle=False, drop_last=False),
|
| 135 |
+
}
|
| 136 |
+
|
| 137 |
+
# Metric Dataloader
|
| 138 |
+
# K-means clustering of ALL points into 2 groups
|
| 139 |
+
if self.metric_clusters == 3:
|
| 140 |
+
km_all = KMeans(n_clusters=3, random_state=45).fit(self.dataset.numpy())
|
| 141 |
+
cluster_labels = km_all.labels_
|
| 142 |
+
|
| 143 |
+
cluster_0_mask = cluster_labels == 0
|
| 144 |
+
cluster_1_mask = cluster_labels == 1
|
| 145 |
+
cluster_2_mask = cluster_labels == 2
|
| 146 |
+
|
| 147 |
+
samples = self.dataset.cpu().numpy()
|
| 148 |
+
|
| 149 |
+
cluster_0_data = samples[cluster_0_mask]
|
| 150 |
+
cluster_1_data = samples[cluster_1_mask]
|
| 151 |
+
cluster_2_data = samples[cluster_2_mask]
|
| 152 |
+
|
| 153 |
+
self.metric_samples_dataloaders = [
|
| 154 |
+
DataLoader(
|
| 155 |
+
torch.tensor(cluster_1_data, dtype=torch.float32),
|
| 156 |
+
batch_size=cluster_1_data.shape[0],
|
| 157 |
+
shuffle=False,
|
| 158 |
+
drop_last=False,
|
| 159 |
+
),
|
| 160 |
+
DataLoader(
|
| 161 |
+
torch.tensor(cluster_2_data, dtype=torch.float32),
|
| 162 |
+
batch_size=cluster_2_data.shape[0],
|
| 163 |
+
shuffle=False,
|
| 164 |
+
drop_last=False,
|
| 165 |
+
),
|
| 166 |
+
|
| 167 |
+
DataLoader(
|
| 168 |
+
torch.tensor(cluster_0_data, dtype=torch.float32),
|
| 169 |
+
batch_size=cluster_0_data.shape[0],
|
| 170 |
+
shuffle=False,
|
| 171 |
+
drop_last=False,
|
| 172 |
+
),
|
| 173 |
+
]
|
| 174 |
+
else:
|
| 175 |
+
km_all = KMeans(n_clusters=2, random_state=45).fit(self.dataset.numpy())
|
| 176 |
+
cluster_labels = km_all.labels_
|
| 177 |
+
|
| 178 |
+
cluster_0_mask = cluster_labels == 0
|
| 179 |
+
cluster_1_mask = cluster_labels == 1
|
| 180 |
+
|
| 181 |
+
samples = self.dataset.cpu().numpy()
|
| 182 |
+
|
| 183 |
+
cluster_0_data = samples[cluster_0_mask]
|
| 184 |
+
cluster_1_data = samples[cluster_1_mask]
|
| 185 |
+
|
| 186 |
+
self.metric_samples_dataloaders = [
|
| 187 |
+
DataLoader(
|
| 188 |
+
torch.tensor(cluster_1_data, dtype=torch.float32),
|
| 189 |
+
batch_size=cluster_1_data.shape[0],
|
| 190 |
+
shuffle=False,
|
| 191 |
+
drop_last=False,
|
| 192 |
+
),
|
| 193 |
+
DataLoader(
|
| 194 |
+
torch.tensor(cluster_0_data, dtype=torch.float32),
|
| 195 |
+
batch_size=cluster_0_data.shape[0],
|
| 196 |
+
shuffle=False,
|
| 197 |
+
drop_last=False,
|
| 198 |
+
),
|
| 199 |
+
]
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
def train_dataloader(self):
|
| 203 |
+
combined_loaders = {
|
| 204 |
+
"train_samples": CombinedLoader(self.train_dataloaders, mode="min_size"),
|
| 205 |
+
"metric_samples": CombinedLoader(
|
| 206 |
+
self.metric_samples_dataloaders, mode="min_size"
|
| 207 |
+
),
|
| 208 |
+
}
|
| 209 |
+
return CombinedLoader(combined_loaders, mode="max_size_cycle")
|
| 210 |
+
|
| 211 |
+
def val_dataloader(self):
|
| 212 |
+
combined_loaders = {
|
| 213 |
+
"val_samples": CombinedLoader(self.val_dataloaders, mode="min_size"),
|
| 214 |
+
"metric_samples": CombinedLoader(
|
| 215 |
+
self.metric_samples_dataloaders, mode="min_size"
|
| 216 |
+
),
|
| 217 |
+
}
|
| 218 |
+
|
| 219 |
+
return CombinedLoader(combined_loaders, mode="max_size_cycle")
|
| 220 |
+
|
| 221 |
+
def test_dataloader(self):
|
| 222 |
+
combined_loaders = {
|
| 223 |
+
"test_samples": CombinedLoader(self.test_dataloaders, mode="min_size"),
|
| 224 |
+
"metric_samples": CombinedLoader(
|
| 225 |
+
self.metric_samples_dataloaders, mode="min_size"
|
| 226 |
+
),
|
| 227 |
+
}
|
| 228 |
+
|
| 229 |
+
return CombinedLoader(combined_loaders, mode="max_size_cycle")
|
| 230 |
+
|
| 231 |
+
def get_manifold_proj(self, points):
|
| 232 |
+
"""Adapted for 2D cell data - uses local neighborhood averaging instead of plane fitting"""
|
| 233 |
+
return partial(self.local_smoothing_op, tree=self.tree, dataset=self.dataset)
|
| 234 |
+
|
| 235 |
+
@staticmethod
|
| 236 |
+
def local_smoothing_op(x, tree, dataset, k=10, temp=1e-3):
|
| 237 |
+
"""
|
| 238 |
+
Apply local smoothing based on k-nearest neighbors in the full dataset
|
| 239 |
+
This replaces the plane projection for 2D manifold regularization
|
| 240 |
+
"""
|
| 241 |
+
points_np = x.detach().cpu().numpy()
|
| 242 |
+
_, idx = tree.query(points_np, k=k)
|
| 243 |
+
nearest_pts = dataset[idx] # Shape: (batch_size, k, 2)
|
| 244 |
+
|
| 245 |
+
# Compute weighted average of neighbors
|
| 246 |
+
dists = (x.unsqueeze(1) - nearest_pts).pow(2).sum(-1, keepdim=True)
|
| 247 |
+
weights = torch.exp(-dists / temp)
|
| 248 |
+
weights = weights / weights.sum(dim=1, keepdim=True)
|
| 249 |
+
|
| 250 |
+
# Weighted average of neighbors
|
| 251 |
+
smoothed = (weights * nearest_pts).sum(dim=1)
|
| 252 |
+
|
| 253 |
+
# Blend original point with smoothed version
|
| 254 |
+
alpha = 0.3 # How much smoothing to apply
|
| 255 |
+
return (1 - alpha) * x + alpha * smoothed
|
| 256 |
+
|
| 257 |
+
def get_timepoint_data(self):
|
| 258 |
+
"""Return data organized by timepoints for visualization"""
|
| 259 |
+
return {
|
| 260 |
+
't0': self.coords_t0,
|
| 261 |
+
't1': self.coords_t1,
|
| 262 |
+
't2_1': self.coords_t2_1,
|
| 263 |
+
't2_2': self.coords_t2_2,
|
| 264 |
+
'time_labels': self.time_labels
|
| 265 |
+
}
|
| 266 |
+
|
| 267 |
+
|
| 268 |
+
|
| 269 |
+
class SingleBranchCellDataModule(pl.LightningDataModule):
    def __init__(self, args):
        super().__init__()
        self.save_hyperparameters()

        self.data_path = "./data/mouse_hematopoiesis.csv"
        self.batch_size = args.batch_size
        self.max_dim = args.dim
        self.whiten = args.whiten
        self.k = 20
        self.n_samples = 1429
        self.num_timesteps = 3  # t=0, t=1, t=2
        self.split_ratios = args.split_ratios
        self.metric_clusters = 3
        self.args = args
        self._prepare_data()

    def _prepare_data(self):
        print("Preparing cell data in SingleBranchCellDataModule")

        df = pd.read_csv(self.data_path)

        # Build dictionary of coordinates by time
        coords_by_t = {
            t: df[df["samples"] == t][["x1", "x2"]].values
            for t in sorted(df["samples"].unique())
        }
        n0 = coords_by_t[0].shape[0]  # number of t=0 points
        self.n_samples = n0  # update n_samples to match the actual data

        x0 = torch.tensor(coords_by_t[0], dtype=torch.float32)       # t=0 coordinates
        x_inter = torch.tensor(coords_by_t[1], dtype=torch.float32)  # intermediate (t=1) coordinates
        x1 = torch.tensor(coords_by_t[2], dtype=torch.float32)       # endpoint (t=2) coordinates

        split_index = int(n0 * self.split_ratios[0])

        if n0 - split_index < self.batch_size:
            split_index = n0 - self.batch_size

        train_x0 = x0[:split_index]
        val_x0 = x0[split_index:]
        train_x1 = x1[:split_index]
        val_x1 = x1[split_index:]

        self.val_x0 = val_x0

        train_x0_weights = torch.full((train_x0.shape[0], 1), fill_value=1.0)
        train_x1_weights = torch.full((train_x1.shape[0], 1), fill_value=0.5)

        val_x0_weights = torch.full((val_x0.shape[0], 1), fill_value=1.0)
        val_x1_weights = torch.full((val_x1.shape[0], 1), fill_value=0.5)

        if self.n_samples - split_index < self.batch_size:
            split_index = self.n_samples - self.batch_size

        self.train_dataloaders = {
            "x0": DataLoader(TensorDataset(train_x0, train_x0_weights), batch_size=self.batch_size, shuffle=True, drop_last=True),
            "x1": DataLoader(TensorDataset(train_x1, train_x1_weights), batch_size=self.batch_size, shuffle=True, drop_last=True),
        }

        self.val_dataloaders = {
            "x0": DataLoader(TensorDataset(val_x0, val_x0_weights), batch_size=self.batch_size, shuffle=False, drop_last=True),
            "x1": DataLoader(TensorDataset(val_x1, val_x1_weights), batch_size=self.batch_size, shuffle=True, drop_last=True),
        }

        all_data = np.vstack([coords_by_t[t] for t in sorted(coords_by_t.keys())])
        self.dataset = torch.tensor(all_data, dtype=torch.float32)
        self.tree = cKDTree(all_data)

        # if whitening is enabled, apply it to the full dataset
        if self.whiten:
            self.scaler = StandardScaler()
            self.dataset = torch.tensor(
                self.scaler.fit_transform(all_data), dtype=torch.float32
            )

        self.test_dataloaders = {
            "x0": DataLoader(TensorDataset(val_x0, val_x0_weights), batch_size=self.val_x0.shape[0], shuffle=False, drop_last=False),
            "dataset": DataLoader(TensorDataset(self.dataset), batch_size=self.dataset.shape[0], shuffle=False, drop_last=False),
        }

        # Metric dataloader:
        # K-means clustering of ALL points into 2 groups
        km_all = KMeans(n_clusters=2, random_state=45).fit(self.dataset.numpy())
        cluster_labels = km_all.labels_

        cluster_0_mask = cluster_labels == 0
        cluster_1_mask = cluster_labels == 1

        samples = self.dataset.cpu().numpy()

        cluster_0_data = samples[cluster_0_mask]
        cluster_1_data = samples[cluster_1_mask]

        self.metric_samples_dataloaders = [
            DataLoader(
                torch.tensor(cluster_1_data, dtype=torch.float32),
                batch_size=cluster_1_data.shape[0],
                shuffle=False,
                drop_last=False,
            ),
            DataLoader(
                torch.tensor(cluster_0_data, dtype=torch.float32),
                batch_size=cluster_0_data.shape[0],
                shuffle=False,
                drop_last=False,
            ),
        ]

    def train_dataloader(self):
        combined_loaders = {
            "train_samples": CombinedLoader(self.train_dataloaders, mode="min_size"),
            "metric_samples": CombinedLoader(
                self.metric_samples_dataloaders, mode="min_size"
            ),
        }
        return CombinedLoader(combined_loaders, mode="max_size_cycle")

    def val_dataloader(self):
        combined_loaders = {
            "val_samples": CombinedLoader(self.val_dataloaders, mode="min_size"),
            "metric_samples": CombinedLoader(
                self.metric_samples_dataloaders, mode="min_size"
            ),
        }

        return CombinedLoader(combined_loaders, mode="max_size_cycle")

    def test_dataloader(self):
        combined_loaders = {
            "test_samples": CombinedLoader(self.test_dataloaders, mode="min_size"),
            "metric_samples": CombinedLoader(
                self.metric_samples_dataloaders, mode="min_size"
            ),
        }

        return CombinedLoader(combined_loaders, mode="max_size_cycle")

    def get_manifold_proj(self, points):
        """Adapted for 2D cell data - uses local neighborhood averaging instead of plane fitting"""
        return partial(self.local_smoothing_op, tree=self.tree, dataset=self.dataset)

    @staticmethod
    def local_smoothing_op(x, tree, dataset, k=10, temp=1e-3):
        """
        Apply local smoothing based on k-nearest neighbors in the full dataset.
        This replaces the plane projection for 2D manifold regularization.
        """
        points_np = x.detach().cpu().numpy()
        _, idx = tree.query(points_np, k=k)
        nearest_pts = dataset[idx]  # Shape: (batch_size, k, 2)

        # Compute Gaussian weights from squared distances to the neighbors
        dists = (x.unsqueeze(1) - nearest_pts).pow(2).sum(-1, keepdim=True)
        weights = torch.exp(-dists / temp)
        weights = weights / weights.sum(dim=1, keepdim=True)

        # Weighted average of neighbors
        smoothed = (weights * nearest_pts).sum(dim=1)

        # Blend original point with smoothed version
        alpha = 0.3  # how much smoothing to apply
        return (1 - alpha) * x + alpha * smoothed

def get_datamodule():
    datamodule = WeightedBranchedCellDataModule(args)
    datamodule.setup(stage="fit")
    return datamodule
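
Reviewer note: both modules expose the same `local_smoothing_op` projector. A minimal standalone sketch of its behavior on synthetic 2-D points (everything below is illustrative and not part of the commit; real use goes through `get_manifold_proj`):

import numpy as np
import torch
from scipy.spatial import cKDTree

dataset = torch.randn(500, 2)                  # stand-in for the full cell dataset
tree = cKDTree(dataset.numpy())                # same k-d tree built in _prepare_data
x = dataset[:8] + 0.01 * torch.randn(8, 2)     # trajectory points near the data

smoothed = SingleBranchCellDataModule.local_smoothing_op(
    x, tree=tree, dataset=dataset, k=10, temp=1e-3
)
assert smoothed.shape == x.shape               # each point is blended with its k-NN average
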
dataloaders/three_branch_data.py
ADDED
@@ -0,0 +1,310 @@
import torch
import sys
sys.argv = ['']
from sklearn.preprocessing import StandardScaler
import pytorch_lightning as pl
from torch.utils.data import DataLoader
from lightning.pytorch.utilities.combined_loader import CombinedLoader
import pandas as pd
import numpy as np
from functools import partial
from scipy.spatial import cKDTree
from sklearn.cluster import KMeans
from torch.utils.data import TensorDataset

from train.parsers_tahoe import parse_args
args = parse_args()

class ThreeBranchTahoeDataModule(pl.LightningDataModule):
    def __init__(self, args):
        super().__init__()
        self.save_hyperparameters()

        self.batch_size = args.batch_size
        self.max_dim = args.dim
        self.whiten = args.whiten
        self.split_ratios = args.split_ratios
        self.num_timesteps = 2
        self.data_path = "./data/Trametinib_5.0uM_pca_and_leidenumap_labels.csv"
        self.args = args

        self._prepare_data()

    def _prepare_data(self):
        df = pd.read_csv(self.data_path, comment='#')
        df = df.iloc[:, 1:]
        df = df.replace('', np.nan)
        pc_cols = df.columns[:50]
        for col in pc_cols:
            df[col] = pd.to_numeric(df[col], errors='coerce')
        leiden_dmso_col = 'leiden_DMSO_TF_0.0uM'
        # note: the "clonidine" variable names are kept from the original script; the column is Trametinib
        leiden_clonidine_col = 'leiden_Trametinib_5.0uM'

        dmso_mask = df[leiden_dmso_col].notna()            # has a Leiden label in the DMSO column
        clonidine_mask = df[leiden_clonidine_col].notna()  # has a Leiden label in the Trametinib column

        dmso_data = df[dmso_mask].copy()
        clonidine_data = df[clonidine_mask].copy()

        # The three Trametinib endpoint clusters used as branch targets
        top_clonidine_clusters = ['1.0', '3.0', '5.0']

        x1_1_data = clonidine_data[clonidine_data[leiden_clonidine_col].astype(str) == top_clonidine_clusters[0]]
        x1_2_data = clonidine_data[clonidine_data[leiden_clonidine_col].astype(str) == top_clonidine_clusters[1]]
        x1_3_data = clonidine_data[clonidine_data[leiden_clonidine_col].astype(str) == top_clonidine_clusters[2]]

        x1_1_coords = x1_1_data[pc_cols].values
        x1_2_coords = x1_2_data[pc_cols].values
        x1_3_coords = x1_3_data[pc_cols].values

        x1_1_coords = x1_1_coords.astype(float)
        x1_2_coords = x1_2_coords.astype(float)
        x1_3_coords = x1_3_coords.astype(float)

        # Target size is the minimum across all three endpoint clusters
        target_size = min(len(x1_1_coords), len(x1_2_coords), len(x1_3_coords))

        # Helper function to select points closest to the centroid
        def select_closest_to_centroid(coords, target_size):
            if len(coords) <= target_size:
                return coords

            # Calculate centroid
            centroid = np.mean(coords, axis=0)

            # Calculate distances to centroid
            distances = np.linalg.norm(coords - centroid, axis=1)

            # Get indices of closest points
            closest_indices = np.argsort(distances)[:target_size]

            return coords[closest_indices]

        # Downsample all endpoint clusters to target size using centroid-based selection
        x1_1_coords = select_closest_to_centroid(x1_1_coords, target_size)
        x1_2_coords = select_closest_to_centroid(x1_2_coords, target_size)
        x1_3_coords = select_closest_to_centroid(x1_3_coords, target_size)

        dmso_cluster_counts = dmso_data[leiden_dmso_col].value_counts()

        # DMSO (unchanged)
        largest_dmso_cluster = dmso_cluster_counts.index[0]
        dmso_cluster_data = dmso_data[dmso_data[leiden_dmso_col] == largest_dmso_cluster]

        dmso_coords = dmso_cluster_data[pc_cols].values

        # Downsample the largest DMSO cluster to the target size;
        # for DMSO we also use centroid-based selection, for consistency
        if len(dmso_coords) >= target_size:
            x0_coords = select_closest_to_centroid(dmso_coords, target_size)
        else:
            # If the largest cluster is smaller than the target, use all of it and pad with other DMSO cells
            remaining_needed = target_size - len(dmso_coords)
            other_dmso_data = dmso_data[dmso_data[leiden_dmso_col] != largest_dmso_cluster]
            other_dmso_coords = other_dmso_data[pc_cols].values

            if len(other_dmso_coords) >= remaining_needed:
                # Select closest to centroid from other DMSO cells
                other_selected = select_closest_to_centroid(other_dmso_coords, remaining_needed)
                x0_coords = np.vstack([dmso_coords, other_selected])
            else:
                # Use all available DMSO cells and reduce target size
                all_dmso_coords = dmso_data[pc_cols].values
                target_size = min(target_size, len(all_dmso_coords))
                x0_coords = select_closest_to_centroid(all_dmso_coords, target_size)

                # Re-select endpoint clusters with the updated target size
                x1_1_coords = select_closest_to_centroid(x1_1_data[pc_cols].values.astype(float), target_size)
                x1_2_coords = select_closest_to_centroid(x1_2_data[pc_cols].values.astype(float), target_size)
                x1_3_coords = select_closest_to_centroid(x1_3_data[pc_cols].values.astype(float), target_size)

        # No need to resample: each endpoint cluster is already at target_size
        # from the centroid-based selection above

        self.n_samples = target_size

        # for plotting
        self.coords_t0 = torch.tensor(x0_coords, dtype=torch.float32)
        self.coords_t1_1 = torch.tensor(x1_1_coords, dtype=torch.float32)
        self.coords_t1_2 = torch.tensor(x1_2_coords, dtype=torch.float32)
        self.coords_t1_3 = torch.tensor(x1_3_coords, dtype=torch.float32)

        self.time_labels = np.concatenate([
            np.zeros(len(self.coords_t0)),    # t=0
            np.ones(len(self.coords_t1_1)),   # t=1
            np.ones(len(self.coords_t1_2)),   # t=1
            np.ones(len(self.coords_t1_3)),   # t=1
        ])

        x0 = torch.tensor(x0_coords, dtype=torch.float32)
        x1_1 = torch.tensor(x1_1_coords, dtype=torch.float32)
        x1_2 = torch.tensor(x1_2_coords, dtype=torch.float32)
        x1_3 = torch.tensor(x1_3_coords, dtype=torch.float32)

        split_index = int(target_size * self.split_ratios[0])

        if target_size - split_index < self.batch_size:
            split_index = target_size - self.batch_size

        train_x0 = x0[:split_index]
        val_x0 = x0[split_index:]
        train_x1_1 = x1_1[:split_index]
        val_x1_1 = x1_1[split_index:]
        train_x1_2 = x1_2[:split_index]
        val_x1_2 = x1_2[split_index:]
        train_x1_3 = x1_3[:split_index]
        val_x1_3 = x1_3[split_index:]

        self.val_x0 = val_x0

        train_x0_weights = torch.full((train_x0.shape[0], 1), fill_value=1.0)
        train_x1_1_weights = torch.full((train_x1_1.shape[0], 1), fill_value=0.603)
        train_x1_2_weights = torch.full((train_x1_2.shape[0], 1), fill_value=0.255)
        train_x1_3_weights = torch.full((train_x1_3.shape[0], 1), fill_value=0.142)

        val_x0_weights = torch.full((val_x0.shape[0], 1), fill_value=1.0)
        val_x1_1_weights = torch.full((val_x1_1.shape[0], 1), fill_value=0.603)
        val_x1_2_weights = torch.full((val_x1_2.shape[0], 1), fill_value=0.255)
        val_x1_3_weights = torch.full((val_x1_3.shape[0], 1), fill_value=0.142)

        # Train dataloaders cover all three endpoint branches
        self.train_dataloaders = {
            "x0": DataLoader(TensorDataset(train_x0, train_x0_weights), batch_size=self.batch_size, shuffle=True, drop_last=True),
            "x1_1": DataLoader(TensorDataset(train_x1_1, train_x1_1_weights), batch_size=self.batch_size, shuffle=True, drop_last=True),
            "x1_2": DataLoader(TensorDataset(train_x1_2, train_x1_2_weights), batch_size=self.batch_size, shuffle=True, drop_last=True),
            "x1_3": DataLoader(TensorDataset(train_x1_3, train_x1_3_weights), batch_size=self.batch_size, shuffle=True, drop_last=True),
        }

        # Val dataloaders cover all three endpoint branches
        self.val_dataloaders = {
            "x0": DataLoader(TensorDataset(val_x0, val_x0_weights), batch_size=self.batch_size, shuffle=False, drop_last=True),
            "x1_1": DataLoader(TensorDataset(val_x1_1, val_x1_1_weights), batch_size=self.batch_size, shuffle=True, drop_last=True),
            "x1_2": DataLoader(TensorDataset(val_x1_2, val_x1_2_weights), batch_size=self.batch_size, shuffle=True, drop_last=True),
            "x1_3": DataLoader(TensorDataset(val_x1_3, val_x1_3_weights), batch_size=self.batch_size, shuffle=True, drop_last=True),
        }

        all_coords = df[pc_cols].dropna().values.astype(float)
        self.dataset = torch.tensor(all_coords, dtype=torch.float32)
        self.tree = cKDTree(all_coords)

        self.test_dataloaders = {
            "x0": DataLoader(TensorDataset(val_x0, val_x0_weights), batch_size=self.val_x0.shape[0], shuffle=False, drop_last=False),
            "dataset": DataLoader(TensorDataset(self.dataset), batch_size=self.dataset.shape[0], shuffle=False, drop_last=False),
        }

        # Metric samples: K-means with 4 clusters on the first 3 PCs
        # km_all = KMeans(n_clusters=4, random_state=42).fit(self.dataset.numpy())
        km_all = KMeans(n_clusters=4, random_state=0).fit(self.dataset[:, :3].numpy())

        cluster_labels = km_all.labels_

        cluster_0_mask = cluster_labels == 0
        cluster_1_mask = cluster_labels == 1
        cluster_2_mask = cluster_labels == 2
        cluster_3_mask = cluster_labels == 3

        samples = self.dataset.cpu().numpy()

        cluster_0_data = samples[cluster_0_mask]
        cluster_1_data = samples[cluster_1_mask]
        cluster_2_data = samples[cluster_2_mask]
        cluster_3_data = samples[cluster_3_mask]

        self.metric_samples_dataloaders = [
            DataLoader(
                torch.tensor(cluster_1_data, dtype=torch.float32),
                batch_size=cluster_1_data.shape[0],
                shuffle=False,
                drop_last=False,
            ),
            DataLoader(
                torch.tensor(cluster_3_data, dtype=torch.float32),
                batch_size=cluster_3_data.shape[0],
                shuffle=False,
                drop_last=False,
            ),
            DataLoader(
                torch.tensor(cluster_2_data, dtype=torch.float32),
                batch_size=cluster_2_data.shape[0],
                shuffle=False,
                drop_last=False,
            ),
            DataLoader(
                torch.tensor(cluster_0_data, dtype=torch.float32),
                batch_size=cluster_0_data.shape[0],
                shuffle=False,
                drop_last=False,
            ),
        ]

    def train_dataloader(self):
        combined_loaders = {
            "train_samples": CombinedLoader(self.train_dataloaders, mode="min_size"),
            "metric_samples": CombinedLoader(
                self.metric_samples_dataloaders, mode="min_size"
            ),
        }
        return CombinedLoader(combined_loaders, mode="max_size_cycle")

    def val_dataloader(self):
        combined_loaders = {
            "val_samples": CombinedLoader(self.val_dataloaders, mode="min_size"),
            "metric_samples": CombinedLoader(
                self.metric_samples_dataloaders, mode="min_size"
            ),
        }

        return CombinedLoader(combined_loaders, mode="max_size_cycle")

    def test_dataloader(self):
        combined_loaders = {
            "test_samples": CombinedLoader(self.test_dataloaders, mode="min_size"),
            "metric_samples": CombinedLoader(
                self.metric_samples_dataloaders, mode="min_size"
            ),
        }

        return CombinedLoader(combined_loaders, mode="max_size_cycle")

    def get_manifold_proj(self, points):
        """Uses local neighborhood averaging instead of plane fitting"""
        return partial(self.local_smoothing_op, tree=self.tree, dataset=self.dataset)

    @staticmethod
    def local_smoothing_op(x, tree, dataset, k=10, temp=1e-3):
        """
        Apply local smoothing based on k-nearest neighbors in the full dataset.
        This replaces the plane projection for manifold regularization.
        """
        points_np = x.detach().cpu().numpy()
        _, idx = tree.query(points_np, k=k)
        nearest_pts = dataset[idx]  # Shape: (batch_size, k, dim)

        # Compute Gaussian weights from squared distances to the neighbors
        dists = (x.unsqueeze(1) - nearest_pts).pow(2).sum(-1, keepdim=True)
        weights = torch.exp(-dists / temp)
        weights = weights / weights.sum(dim=1, keepdim=True)

        # Weighted average of neighbors
        smoothed = (weights * nearest_pts).sum(dim=1)

        # Blend original point with smoothed version
        alpha = 0.3  # how much smoothing to apply
        return (1 - alpha) * x + alpha * smoothed

    def get_timepoint_data(self):
        """Return data organized by timepoints for visualization"""
        return {
            't0': self.coords_t0,
            't1_1': self.coords_t1_1,
            't1_2': self.coords_t1_2,
            't1_3': self.coords_t1_3,
            'time_labels': self.time_labels
        }

def get_datamodule():
    datamodule = ThreeBranchTahoeDataModule(args)
    datamodule.setup(stage="fit")
    return datamodule
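
The centroid-based downsampling used above keeps exactly the `target_size` points nearest each cluster mean, preserving the cluster core rather than sampling at random. A self-contained check of that behavior on made-up data:

import numpy as np

def select_closest_to_centroid(coords, target_size):  # same helper as in _prepare_data
    if len(coords) <= target_size:
        return coords
    centroid = np.mean(coords, axis=0)
    distances = np.linalg.norm(coords - centroid, axis=1)
    return coords[np.argsort(distances)[:target_size]]

rng = np.random.default_rng(0)
coords = rng.normal(size=(100, 50))                   # 100 cells in 50 PCs
subset = select_closest_to_centroid(coords, 20)

d_all = np.sort(np.linalg.norm(coords - coords.mean(axis=0), axis=1))
d_sub = np.sort(np.linalg.norm(subset - coords.mean(axis=0), axis=1))
assert subset.shape == (20, 50) and np.allclose(d_sub, d_all[:20])
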
dataloaders/trametinib_single.py
ADDED
@@ -0,0 +1,279 @@
import torch
import sys
sys.argv = ['']
from sklearn.preprocessing import StandardScaler
import pytorch_lightning as pl
from torch.utils.data import DataLoader
from lightning.pytorch.utilities.combined_loader import CombinedLoader
import pandas as pd
import numpy as np
from functools import partial
from scipy.spatial import cKDTree
from sklearn.cluster import KMeans
from torch.utils.data import TensorDataset

from train.parsers_tahoe import parse_args
args = parse_args()

class TrametinibSingleBranchDataModule(pl.LightningDataModule):
    def __init__(self, args):
        super().__init__()
        self.save_hyperparameters()

        self.batch_size = args.batch_size
        self.max_dim = args.dim
        self.whiten = args.whiten
        self.split_ratios = args.split_ratios
        self.num_timesteps = 2
        self.data_path = "./data/Trametinib_5.0uM_pca_and_leidenumap_labels.csv"
        self.args = args

        self._prepare_data()

    def _prepare_data(self):
        df = pd.read_csv(self.data_path, comment='#')
        df = df.iloc[:, 1:]
        df = df.replace('', np.nan)
        pc_cols = df.columns[:50]
        for col in pc_cols:
            df[col] = pd.to_numeric(df[col], errors='coerce')
        leiden_dmso_col = 'leiden_DMSO_TF_0.0uM'
        # note: the "clonidine" variable names are kept from the original script; the column is Trametinib
        leiden_clonidine_col = 'leiden_Trametinib_5.0uM'

        dmso_mask = df[leiden_dmso_col].notna()            # has a Leiden label in the DMSO column
        clonidine_mask = df[leiden_clonidine_col].notna()  # has a Leiden label in the Trametinib column

        dmso_data = df[dmso_mask].copy()
        clonidine_data = df[clonidine_mask].copy()

        # The three Trametinib endpoint clusters
        top_clonidine_clusters = ['1.0', '3.0', '5.0']

        x1_1_data = clonidine_data[clonidine_data[leiden_clonidine_col].astype(str) == top_clonidine_clusters[0]]
        x1_2_data = clonidine_data[clonidine_data[leiden_clonidine_col].astype(str) == top_clonidine_clusters[1]]
        x1_3_data = clonidine_data[clonidine_data[leiden_clonidine_col].astype(str) == top_clonidine_clusters[2]]

        x1_1_coords = x1_1_data[pc_cols].values
        x1_2_coords = x1_2_data[pc_cols].values
        x1_3_coords = x1_3_data[pc_cols].values

        x1_1_coords = x1_1_coords.astype(float)
        x1_2_coords = x1_2_coords.astype(float)
        x1_3_coords = x1_3_coords.astype(float)

        # Target size is the minimum across all three endpoint clusters
        target_size = min(len(x1_1_coords), len(x1_2_coords), len(x1_3_coords))

        # Helper function to select points closest to the centroid
        def select_closest_to_centroid(coords, target_size):
            if len(coords) <= target_size:
                return coords

            # Calculate centroid
            centroid = np.mean(coords, axis=0)

            # Calculate distances to centroid
            distances = np.linalg.norm(coords - centroid, axis=1)

            # Get indices of closest points
            closest_indices = np.argsort(distances)[:target_size]

            return coords[closest_indices]

        # Downsample all endpoint clusters to target size using centroid-based selection
        x1_1_coords = select_closest_to_centroid(x1_1_coords, target_size)
        x1_2_coords = select_closest_to_centroid(x1_2_coords, target_size)
        x1_3_coords = select_closest_to_centroid(x1_3_coords, target_size)

        dmso_cluster_counts = dmso_data[leiden_dmso_col].value_counts()

        # DMSO (unchanged)
        largest_dmso_cluster = dmso_cluster_counts.index[0]
        dmso_cluster_data = dmso_data[dmso_data[leiden_dmso_col] == largest_dmso_cluster]

        dmso_coords = dmso_cluster_data[pc_cols].values

        # Downsample the largest DMSO cluster to the target size;
        # for DMSO we also use centroid-based selection, for consistency
        if len(dmso_coords) >= target_size:
            x0_coords = select_closest_to_centroid(dmso_coords, target_size)
        else:
            # If the largest cluster is smaller than the target, use all of it and pad with other DMSO cells
            remaining_needed = target_size - len(dmso_coords)
            other_dmso_data = dmso_data[dmso_data[leiden_dmso_col] != largest_dmso_cluster]
            other_dmso_coords = other_dmso_data[pc_cols].values

            if len(other_dmso_coords) >= remaining_needed:
                # Select closest to centroid from other DMSO cells
                other_selected = select_closest_to_centroid(other_dmso_coords, remaining_needed)
                x0_coords = np.vstack([dmso_coords, other_selected])
            else:
                # Use all available DMSO cells and reduce target size
                all_dmso_coords = dmso_data[pc_cols].values
                target_size = min(target_size, len(all_dmso_coords))
                x0_coords = select_closest_to_centroid(all_dmso_coords, target_size)

                # Re-select endpoint clusters with the updated target size
                x1_1_coords = select_closest_to_centroid(x1_1_data[pc_cols].values.astype(float), target_size)
                x1_2_coords = select_closest_to_centroid(x1_2_data[pc_cols].values.astype(float), target_size)
                x1_3_coords = select_closest_to_centroid(x1_3_data[pc_cols].values.astype(float), target_size)

        # No need to resample: each endpoint cluster is already at target_size
        # from the centroid-based selection above

        self.n_samples = target_size

        # for plotting

        x0 = torch.tensor(x0_coords, dtype=torch.float32)
        x1_1 = torch.tensor(x1_1_coords, dtype=torch.float32)
        x1_2 = torch.tensor(x1_2_coords, dtype=torch.float32)
        x1_3 = torch.tensor(x1_3_coords, dtype=torch.float32)
        x1 = torch.cat([x1_1, x1_2, x1_3], dim=0)

        self.coords_t0 = x0
        self.coords_t1 = x1

        self.time_labels = np.concatenate([
            np.zeros(len(self.coords_t0)),  # t=0
            np.ones(len(self.coords_t1)),   # t=1
        ])

        split_index = int(target_size * self.split_ratios[0])

        if target_size - split_index < self.batch_size:
            split_index = target_size - self.batch_size

        train_x0 = x0[:split_index]
        val_x0 = x0[split_index:]
        train_x1 = x1_1[:split_index]
        val_x1 = x1_1[split_index:]

        self.val_x0 = val_x0

        train_x0_weights = torch.full((train_x0.shape[0], 1), fill_value=1.0)
        train_x1_weights = torch.full((train_x1.shape[0], 1), fill_value=1.0)

        val_x0_weights = torch.full((val_x0.shape[0], 1), fill_value=1.0)
        val_x1_weights = torch.full((val_x1.shape[0], 1), fill_value=1.0)

        # Single-branch training uses x0 and the first endpoint cluster (x1_1)
        self.train_dataloaders = {
            "x0": DataLoader(TensorDataset(train_x0, train_x0_weights), batch_size=self.batch_size, shuffle=True, drop_last=True),
            "x1": DataLoader(TensorDataset(train_x1, train_x1_weights), batch_size=self.batch_size, shuffle=True, drop_last=True),
        }

        # Single-branch validation dataloaders
        self.val_dataloaders = {
            "x0": DataLoader(TensorDataset(val_x0, val_x0_weights), batch_size=self.batch_size, shuffle=False, drop_last=True),
            "x1": DataLoader(TensorDataset(val_x1, val_x1_weights), batch_size=self.batch_size, shuffle=True, drop_last=True),
        }

        all_coords = df[pc_cols].dropna().values.astype(float)
        self.dataset = torch.tensor(all_coords, dtype=torch.float32)
        self.tree = cKDTree(all_coords)

        self.test_dataloaders = {
            "x0": DataLoader(TensorDataset(val_x0, val_x0_weights), batch_size=self.val_x0.shape[0], shuffle=False, drop_last=False),
            "dataset": DataLoader(TensorDataset(self.dataset), batch_size=self.dataset.shape[0], shuffle=False, drop_last=False),
        }

        # Metric samples: K-means with 2 clusters on the first 3 PCs
        # km_all = KMeans(n_clusters=4, random_state=42).fit(self.dataset.numpy())
        km_all = KMeans(n_clusters=2, random_state=0).fit(self.dataset[:, :3].numpy())

        cluster_labels = km_all.labels_

        cluster_0_mask = cluster_labels == 0
        cluster_1_mask = cluster_labels == 1

        samples = self.dataset.cpu().numpy()

        cluster_0_data = samples[cluster_0_mask]
        cluster_1_data = samples[cluster_1_mask]

        self.metric_samples_dataloaders = [
            DataLoader(
                torch.tensor(cluster_1_data, dtype=torch.float32),
                batch_size=cluster_1_data.shape[0],
                shuffle=False,
                drop_last=False,
            ),
            DataLoader(
                torch.tensor(cluster_0_data, dtype=torch.float32),
                batch_size=cluster_0_data.shape[0],
                shuffle=False,
                drop_last=False,
            ),
        ]

    def train_dataloader(self):
        combined_loaders = {
            "train_samples": CombinedLoader(self.train_dataloaders, mode="min_size"),
            "metric_samples": CombinedLoader(
                self.metric_samples_dataloaders, mode="min_size"
            ),
        }
        return CombinedLoader(combined_loaders, mode="max_size_cycle")

    def val_dataloader(self):
        combined_loaders = {
            "val_samples": CombinedLoader(self.val_dataloaders, mode="min_size"),
            "metric_samples": CombinedLoader(
                self.metric_samples_dataloaders, mode="min_size"
            ),
        }

        return CombinedLoader(combined_loaders, mode="max_size_cycle")

    def test_dataloader(self):
        combined_loaders = {
            "test_samples": CombinedLoader(self.test_dataloaders, mode="min_size"),
            "metric_samples": CombinedLoader(
                self.metric_samples_dataloaders, mode="min_size"
            ),
        }

        return CombinedLoader(combined_loaders, mode="max_size_cycle")

    def get_manifold_proj(self, points):
        """Uses local neighborhood averaging instead of plane fitting"""
        return partial(self.local_smoothing_op, tree=self.tree, dataset=self.dataset)

    @staticmethod
    def local_smoothing_op(x, tree, dataset, k=10, temp=1e-3):
        """
        Apply local smoothing based on k-nearest neighbors in the full dataset.
        This replaces the plane projection for manifold regularization.
        """
        points_np = x.detach().cpu().numpy()
        _, idx = tree.query(points_np, k=k)
        nearest_pts = dataset[idx]  # Shape: (batch_size, k, dim)

        # Compute Gaussian weights from squared distances to the neighbors
        dists = (x.unsqueeze(1) - nearest_pts).pow(2).sum(-1, keepdim=True)
        weights = torch.exp(-dists / temp)
        weights = weights / weights.sum(dim=1, keepdim=True)

        # Weighted average of neighbors
        smoothed = (weights * nearest_pts).sum(dim=1)

        # Blend original point with smoothed version
        alpha = 0.3  # how much smoothing to apply
        return (1 - alpha) * x + alpha * smoothed

    def get_timepoint_data(self):
        """Return data organized by timepoints for visualization"""
        return {
            't0': self.coords_t0,
            't1': self.coords_t1,
            'time_labels': self.time_labels
        }

def get_datamodule():
    datamodule = TrametinibSingleBranchDataModule(args)
    datamodule.setup(stage="fit")
    return datamodule
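
These modules nest `CombinedLoader`s two deep, which is why downstream code (e.g., `training_step` in `state_costs/rbf.py`) unpacks batches as `batch[0]["train_samples"][0]["x0"][0]`. A synthetic sketch of that layout, assuming Lightning 2.x `CombinedLoader` semantics (iteration yielding `(batch, batch_idx, dataloader_idx)` triples); the exact tuple structure may differ inside a `Trainer`:

import torch
from torch.utils.data import DataLoader, TensorDataset
from lightning.pytorch.utilities.combined_loader import CombinedLoader

train = {
    "x0": DataLoader(TensorDataset(torch.randn(32, 50), torch.ones(32, 1)), batch_size=8),
    "x1": DataLoader(TensorDataset(torch.randn(32, 50), torch.ones(32, 1)), batch_size=8),
}
metric = [DataLoader(torch.randn(64, 50), batch_size=64)]  # per-cluster full-batch loaders

combined = CombinedLoader(
    {
        "train_samples": CombinedLoader(train, mode="min_size"),
        "metric_samples": CombinedLoader(metric, mode="min_size"),
    },
    mode="max_size_cycle",
)
batch, batch_idx, dataloader_idx = next(iter(combined))
x0, x0_weights = batch["train_samples"][0]["x0"]           # mirrors rbf.py's indexing
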
losses/.DS_Store
ADDED
Binary file (6.15 kB)
losses/energy_loss.py
ADDED
@@ -0,0 +1,73 @@
import os, math, numpy as np
import torch
import torch.nn as nn
from torchdiffeq import odeint as odeint2
from torchmetrics.functional import mean_squared_error
import ot

class EnergySolver(nn.Module):
    def __init__(self, flow_net, growth_net, state_cost, data_manifold_metric=None, samples=None):
        super(EnergySolver, self).__init__()
        self.flow_net = flow_net
        self.growth_net = growth_net
        self.state_cost = state_cost

        self.data_manifold_metric = data_manifold_metric
        self.samples = samples

    def forward(self, t, state):
        xt, wt, mt = state

        xt.requires_grad_(True)
        wt.requires_grad_(True)
        mt.requires_grad_(True)

        t.requires_grad_(True)

        ut = self.flow_net(t, xt)
        gt = self.growth_net(t, xt)

        time = t.expand(xt.shape[0], 1)
        time.requires_grad_(True)

        dx_dt = ut
        dw_dt = gt

        if self.data_manifold_metric is not None:
            vel, _, _ = self.data_manifold_metric.calculate_velocity(
                xt, ut, self.samples, 0
            )

            dm_dt = torch.mean(vel ** 2) * wt
        else:
            # keepdim/reshape hold dm_dt at shape (batch, 1) so it matches mt below
            dm_dt = ((ut**2).sum(dim=-1, keepdim=True) + self.state_cost(xt).reshape(-1, 1)) * wt

        assert xt.shape == dx_dt.shape, f"dx mismatch: expected {xt.shape}, got {dx_dt.shape}"
        assert wt.shape == dw_dt.shape, f"dw mismatch: expected {wt.shape}, got {dw_dt.shape}"
        assert mt.shape == dm_dt.shape, f"dm mismatch: expected {mt.shape}, got {dm_dt.shape}"
        return dx_dt, dw_dt, dm_dt

class ReconsLoss(nn.Module):
    def __init__(self, hinge_value=0.01):
        super(ReconsLoss, self).__init__()
        self.hinge_value = hinge_value

    def __call__(self, source, target, groups=None, to_ignore=None, top_k=5):
        if groups is not None:
            # for global loss
            c_dist = torch.stack([
                torch.cdist(source[i], target[i])
                for i in range(1, len(groups))
                if groups[i] != to_ignore
            ])
        else:
            # for local loss
            c_dist = torch.stack([
                torch.cdist(source, target)
            ])
        values, _ = torch.topk(c_dist, top_k, dim=2, largest=False, sorted=False)
        values -= self.hinge_value
        values[values < 0] = 0
        loss = torch.mean(values)
        return loss
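
A sketch of how `EnergySolver` plugs into `torchdiffeq.odeint`: the ODE state is the tuple (positions, branch weights, accumulated energy), so `m` at t=1 is the weighted path energy. The networks, dimensions, and constant state cost below are illustrative choices, not values from the repo:

import torch
from torchdiffeq import odeint

from losses.energy_loss import EnergySolver
from networks.flow_mlp import VelocityNet
from networks.growth_mlp import GrowthNet

dim, batch = 2, 16
flow_net = VelocityNet(dim=dim, activation="silu", hidden_dims=[64, 64])
growth_net = GrowthNet(dim=dim, activation="silu", hidden_dims=[64, 64])
state_cost = lambda x: torch.zeros(len(x))   # placeholder state cost, shape (batch,)

solver = EnergySolver(flow_net, growth_net, state_cost)

x0 = torch.randn(batch, dim)                 # initial positions
w0 = torch.ones(batch, 1)                    # initial branch weights
m0 = torch.zeros(batch, 1)                   # accumulated energy starts at zero
ts = torch.linspace(0.0, 1.0, 11)

xs, ws, ms = odeint(solver, (x0, w0, m0), ts)
# xs[-1]: endpoint positions, ws[-1]: endpoint weights, ms[-1]: path energy per sample
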
networks/.DS_Store
ADDED
Binary file (6.15 kB)
networks/flow_mlp.py
ADDED
@@ -0,0 +1,18 @@
import sys
sys.path.append("./BranchSBM")
import torch
from networks.mlp_base import SimpleDenseNet


class VelocityNet(SimpleDenseNet):
    def __init__(self, dim: int, *args, **kwargs):
        super().__init__(input_size=dim + 1, target_size=dim, *args, **kwargs)

    def forward(self, t, x):
        if t.dim() < 1 or t.shape[0] != x.shape[0]:
            t = t.repeat(x.shape[0])[:, None]
        if t.dim() < 2:
            t = t[:, None]
        x = torch.cat([t, x], dim=-1)
        return self.model(x)
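
`VelocityNet` accepts either a scalar time or a per-sample time vector and broadcasts it to a `(batch, 1)` column before concatenating it with the state. A quick shape check (hyperparameters here are arbitrary):

import torch
from networks.flow_mlp import VelocityNet

net = VelocityNet(dim=2, activation="silu", hidden_dims=[32, 32])
x = torch.randn(4, 2)
v_scalar = net(torch.tensor(0.3), x)   # scalar t is repeated across the batch
v_vector = net(torch.rand(4), x)       # per-sample t just gains a feature dim
assert v_scalar.shape == v_vector.shape == (4, 2)
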
networks/growth_mlp.py
ADDED
@@ -0,0 +1,37 @@
import sys
sys.path.append("./BranchSBM")
import torch
import torch.nn as nn
from typing import List, Optional

from networks.mlp_base import SimpleDenseNet

class GrowthNet(SimpleDenseNet):
    def __init__(
        self,
        dim: int,
        activation: str,
        hidden_dims: List[int] = None,
        batch_norm: bool = False,
        negative: bool = False
    ):
        super().__init__(input_size=dim + 1, target_size=1,
                         activation=activation,
                         batch_norm=batch_norm,
                         hidden_dims=hidden_dims)

        self.softplus = nn.Softplus()
        self.negative = negative

    def forward(self, t, x):
        if t.dim() < 1 or t.shape[0] != x.shape[0]:
            t = t.repeat(x.shape[0])[:, None]
        if t.dim() < 2:
            t = t[:, None]
        x = torch.cat([t, x], dim=-1)
        # apply the MLP once and squash through Softplus so the growth rate is positive
        x = self.softplus(self.model(x))
        if self.negative:
            x = -x
        return x
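
With the MLP applied once as above, `GrowthNet` maps `(t, x)` to a single Softplus-positive rate per sample (negated when `negative=True`). A quick check with arbitrary hyperparameters:

import torch
from networks.growth_mlp import GrowthNet

g = GrowthNet(dim=2, activation="silu", hidden_dims=[32, 32])
rates = g(torch.tensor(0.5), torch.randn(4, 2))
assert rates.shape == (4, 1) and (rates > 0).all()

g_neg = GrowthNet(dim=2, activation="silu", hidden_dims=[32, 32], negative=True)
assert (g_neg(torch.tensor(0.5), torch.randn(4, 2)) < 0).all()
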
networks/interpolant_mlp.py
ADDED
@@ -0,0 +1,35 @@
import sys
sys.path.append("./BranchSBM")
import torch
import torch.nn as nn
from typing import List, Optional

from networks.mlp_base import SimpleDenseNet

class GeoPathMLP(nn.Module):
    def __init__(
        self,
        input_dim: int,
        activation: str,
        batch_norm: bool = True,
        hidden_dims: Optional[List[int]] = None,
        time_geopath: bool = False,
    ):
        super().__init__()
        self.input_dim = input_dim
        self.time_geopath = time_geopath
        self.mainnet = SimpleDenseNet(
            input_size=2 * input_dim + (1 if time_geopath else 0),
            target_size=input_dim,
            activation=activation,
            batch_norm=batch_norm,
            hidden_dims=hidden_dims,
        )

    def forward(
        self, x0: torch.Tensor, x1: torch.Tensor, t: torch.Tensor
    ) -> torch.Tensor:
        x = torch.cat([x0, x1], dim=1)
        if self.time_geopath:
            x = torch.cat([x, t], dim=1)
        return self.mainnet(x)
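
`GeoPathMLP` only outputs a correction term; one common way to use such a network is to bend the straight-line interpolant while pinning the endpoints with a `t(1 - t)` factor. The parameterization below is an assumption about how the training code consumes it, shown purely for illustration:

import torch
from networks.interpolant_mlp import GeoPathMLP

dim, batch = 2, 8
geopath = GeoPathMLP(input_dim=dim, activation="silu", batch_norm=False,
                     hidden_dims=[64, 64], time_geopath=True)

x0, x1 = torch.randn(batch, dim), torch.randn(batch, dim)
t = torch.rand(batch, 1)

# straight line plus a learned correction that vanishes at t = 0 and t = 1
x_t = (1 - t) * x0 + t * x1 + t * (1 - t) * geopath(x0, x1, t)
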
networks/mlp_base.py
ADDED
@@ -0,0 +1,46 @@
import sys
sys.path.append("./BranchSBM")
import torch.nn as nn
import torch
from typing import List, Optional

class swish(nn.Module):
    def forward(self, x):
        return x * torch.sigmoid(x)


ACTIVATION_MAP = {
    "relu": nn.ReLU,
    "sigmoid": nn.Sigmoid,
    "tanh": nn.Tanh,
    "selu": nn.SELU,
    "elu": nn.ELU,
    "lrelu": nn.LeakyReLU,
    "softplus": nn.Softplus,
    "silu": nn.SiLU,
    "swish": swish,
}


class SimpleDenseNet(nn.Module):
    def __init__(
        self,
        input_size: int,
        target_size: int,
        activation: str,
        batch_norm: bool = False,
        hidden_dims: List[int] = None,
    ):
        super().__init__()
        dims = [input_size, *hidden_dims, target_size]
        layers = []
        for i in range(len(dims) - 2):
            layers.append(nn.Linear(dims[i], dims[i + 1]))
            if batch_norm:
                layers.append(nn.BatchNorm1d(dims[i + 1]))
            layers.append(ACTIVATION_MAP[activation]())
        layers.append(nn.Linear(dims[-2], dims[-1]))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)
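
`SimpleDenseNet` is the shared backbone: `Linear -> (BatchNorm) -> activation` blocks followed by a linear head. Minimal usage (sizes are arbitrary):

import torch
from networks.mlp_base import SimpleDenseNet

net = SimpleDenseNet(input_size=3, target_size=2, activation="swish",
                     batch_norm=False, hidden_dims=[64, 64])
assert net(torch.randn(5, 3)).shape == (5, 2)
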
networks/utils.py
ADDED
@@ -0,0 +1,13 @@
import sys
sys.path.append("./BranchSBM")
import torch

class flow_model_torch_wrapper(torch.nn.Module):
    """Wraps model to torchdyn compatible format."""

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, t, x, *args, **kwargs):
        return self.model(t, x)
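
The wrapper simply forwards `(t, x)` and drops extra arguments, which matches the vector-field signature torchdyn expects. A hedged sketch of integrating a (randomly initialized) flow with it; the torchdyn usage is an assumption based on its 1.x API:

import torch
from torchdyn.core import NeuralODE

from networks.flow_mlp import VelocityNet
from networks.utils import flow_model_torch_wrapper

flow = flow_model_torch_wrapper(VelocityNet(dim=2, activation="silu", hidden_dims=[32]))
node = NeuralODE(flow, solver="dopri5")
t_span = torch.linspace(0.0, 1.0, 51)
t_eval, traj = node(torch.randn(16, 2), t_span)   # traj: (len(t_span), 16, 2)
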
state_costs/.DS_Store
ADDED
Binary file (6.15 kB)
state_costs/land.py
ADDED
@@ -0,0 +1,26 @@
import torch


def weighting_function(x, samples, gamma):
    pairwise_sq_diff = (x[:, None, :] - samples[None, :, :]) ** 2
    pairwise_sq_dist = pairwise_sq_diff.sum(-1)
    weights = torch.exp(-pairwise_sq_dist / (2 * gamma**2))
    return weights


def land_metric_tensor(x, samples, gamma, rho):
    weights = weighting_function(x, samples, gamma)     # Shape [B, N]
    differences = samples[None, :, :] - x[:, None, :]   # Shape [B, N, D]
    squared_differences = differences**2                # Shape [B, N, D]

    # Compute the sum of weighted squared differences for each dimension
    M_dd_diag = torch.einsum("bn,bnd->bd", weights, squared_differences) + rho

    # Invert the metric tensor diagonal for each x_t
    M_dd_inv_diag = 1.0 / M_dd_diag  # Shape [B, D] since it's diagonal
    return M_dd_inv_diag


def weighting_function_dt(x, dx_dt, samples, gamma, weights):
    pairwise_sq_diff_dt = (x[:, None, :] - samples[None, :, :]) * dx_dt[:, None, :]
    return -pairwise_sq_diff_dt.sum(-1) * weights / (gamma**2)
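
For reference, the diagonal LAND metric implemented here: each dimension accumulates Gaussian-weighted squared deviations to the data samples, regularized by $\rho$,

$w_n(x) = \exp\left(-\lVert x - x_n \rVert^2 / (2\gamma^2)\right), \qquad M_{dd}(x) = \sum_{n=1}^{N} w_n(x)\,(x_{n,d} - x_d)^2 + \rho,$

and `land_metric_tensor` returns the inverse diagonal $1 / M_{dd}(x)$.
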
state_costs/metric_factory.py
ADDED
@@ -0,0 +1,105 @@
import sys
sys.path.append("./BranchSBM")
import torch
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
from torch.utils.data import Dataset, DataLoader

from state_costs.land import land_metric_tensor
from state_costs.rbf import RBFNetwork

class DataManifoldMetric:
    def __init__(
        self,
        args,
        skipped_time_points=None,
        datamodule=None,
    ):
        self.skipped_time_points = skipped_time_points
        self.datamodule = datamodule

        self.gamma = args.gamma_current
        self.rho = args.rho
        self.metric = args.velocity_metric
        self.n_centers = args.n_centers
        self.kappa = args.kappa
        self.metric_epochs = args.metric_epochs
        self.metric_patience = args.metric_patience
        self.lr = args.metric_lr
        self.alpha_metric = args.alpha_metric
        self.image_data = args.data_type == "image"
        self.accelerator = args.accelerator

        self.called_first_time = True
        self.args = args

    def calculate_metric(self, x_t, samples, current_timestep):
        if self.metric == "land":
            M_dd_x_t = (
                land_metric_tensor(x_t, samples, self.gamma, self.rho)
                ** self.alpha_metric
            )

        elif self.metric == "rbf":
            if self.called_first_time:
                self.rbf_networks = []
                for timestep in range(self.datamodule.num_timesteps - 1):
                    if timestep in self.skipped_time_points:
                        continue
                    print("Learning RBF networks, timestep: ", timestep)
                    rbf_network = RBFNetwork(
                        current_timestep=timestep,
                        next_timestep=timestep
                        + 1
                        + (1 if timestep + 1 in self.skipped_time_points else 0),
                        n_centers=self.n_centers,
                        kappa=self.kappa,
                        lr=self.lr,
                        datamodule=self.datamodule,
                        args=self.args
                    )
                    early_stop_callback = pl.callbacks.EarlyStopping(
                        monitor="MetricModel/val_loss_learn_metric",
                        patience=self.metric_patience,
                        mode="min",
                    )
                    trainer = pl.Trainer(
                        max_epochs=self.metric_epochs,
                        accelerator=self.accelerator,
                        logger=WandbLogger(),
                        num_sanity_val_steps=0,
                        callbacks=(
                            [early_stop_callback] if not self.image_data else None
                        ),
                    )
                    if self.image_data:
                        self.dataloader = DataLoader(
                            self.datamodule.all_data,
                            batch_size=128,
                            shuffle=True,
                        )
                        trainer.fit(rbf_network, self.dataloader)
                    else:
                        trainer.fit(rbf_network, self.datamodule)
                    self.rbf_networks.append(rbf_network)
                self.called_first_time = False
                print("Learning RBF networks... Done")
            M_dd_x_t = self.rbf_networks[current_timestep].compute_metric(
                x_t,
                epsilon=self.rho,
                alpha=self.alpha_metric,
                image_hx=self.image_data,
            )
        return M_dd_x_t

    def calculate_velocity(self, x_t, u_t, samples, timestep):
        if len(u_t.shape) > 2:
            u_t = u_t.reshape(u_t.shape[0], -1)
            x_t = x_t.reshape(x_t.shape[0], -1)
        M_dd_x_t = self.calculate_metric(x_t, samples, timestep).to(u_t.device)

        velocity = torch.sqrt(((u_t**2) * M_dd_x_t).sum(dim=-1))
        ut_sum = (u_t**2).sum(dim=-1)
        metric_sum = M_dd_x_t.sum(dim=-1)
        return velocity, ut_sum, metric_sum
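
`calculate_velocity` evaluates the Riemannian speed of the flow under the diagonal metric from `calculate_metric`:

$\lVert u_t \rVert_{M(x_t)} = \sqrt{\textstyle\sum_d u_{t,d}^2 \, M_{dd}(x_t)},$

with the raw sums `ut_sum` and `metric_sum` returned alongside for diagnostics.
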
state_costs/rbf.py
ADDED
|
@@ -0,0 +1,156 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import pytorch_lightning as pl
import torch
from sklearn.cluster import KMeans
import numpy as np


class RBFNetwork(pl.LightningModule):
    # Learns a scalar field h(x) that is trained toward 1 on observed data;
    # its regularized inverse serves as the state-cost metric M(x).
    def __init__(
        self,
        current_timestep,
        next_timestep,
        n_centers: int = 100,
        kappa: float = 1.0,
        lr=1e-2,
        datamodule=None,
        image_data=False,
        args=None,
    ):
        super().__init__()
        self.K = n_centers
        self.current_timestep = current_timestep
        self.next_timestep = next_timestep
        self.clustering_model = KMeans(n_clusters=self.K)
        self.kappa = kappa
        self.last_val_loss = 1
        self.lr = lr
        self.W = torch.nn.Parameter(torch.rand(self.K, 1))
        self.datamodule = datamodule
        self.image_data = image_data
        self.args = args

    def on_before_zero_grad(self, *args, **kwargs):
        # Keep the RBF weights strictly positive.
        self.W.data = torch.clamp(self.W.data, min=0.0001)

    def on_train_start(self):
        with torch.no_grad():
            batch = next(iter(self.trainer.datamodule.train_dataloader()))

            metric_samples = batch[0]["metric_samples"][0]
            all_data = torch.cat(metric_samples)
            data_to_fit = all_data

            print("Fitting Clustering model...")
            self.clustering_model.fit(data_to_fit)

            clusters = (
                self.calculate_centroids(all_data, self.clustering_model.labels_)
                if self.image_data
                else self.clustering_model.cluster_centers_
            )

            self.C = torch.tensor(clusters, dtype=torch.float32).to(self.device)
            labels = self.clustering_model.labels_
            sigmas = np.zeros((self.K, 1))

            # Per-cluster bandwidths from the within-cluster variance.
            for k in range(self.K):
                points = all_data[labels == k, :]
                variance = ((points - clusters[k]) ** 2).mean(axis=0)
                sigmas[k, :] = np.sqrt(
                    variance.sum() if self.image_data else variance.mean()
                )

            self.lamda = torch.tensor(
                0.5 / (self.kappa * sigmas) ** 2, dtype=torch.float32
            ).to(self.device)

    def forward(self, x):
        if len(x.shape) > 2:
            x = x.reshape(x.shape[0], -1).to(self.C.device)

        x = x.to(self.C.device)
        dist2 = torch.cdist(x, self.C) ** 2
        self.phi_x = torch.exp(-0.5 * self.lamda[None, :, :] * dist2[:, :, None])

        h_x = (self.W.to(x.device) * self.phi_x).sum(dim=1)

        return h_x

    def training_step(self, batch, batch_idx):
        if self.args.data_type == "scrna" or self.args.data_type == "tahoe":
            main_batch = batch[0]["train_samples"][0]
        else:
            main_batch = batch["train_samples"][0]

        x0 = main_batch["x0"][0]
        if self.args.branches == 1:
            x1 = main_batch["x1"][0]
            inputs = torch.cat([x0, x1], dim=0).to(self.device)
        else:
            x1_1 = main_batch["x1_1"][0]
            x1_2 = main_batch["x1_2"][0]

            inputs = torch.cat([x0, x1_1, x1_2], dim=0).to(self.device)
            print("inputs shape")
            print(inputs.shape)

        # Regress h(x) toward 1 on observed data points.
        loss = ((1 - self.forward(inputs)) ** 2).mean()
        self.log(
            "MetricModel/train_loss_learn_metric",
            loss,
            on_step=True,
            on_epoch=True,
            prog_bar=True,
        )
        return loss

    def validation_step(self, batch, batch_idx):
        if self.args.data_type == "scrna" or self.args.data_type == "tahoe":
            main_batch = batch[0]["val_samples"][0]
        else:
            main_batch = batch["val_samples"][0]

        x0 = main_batch["x0"][0]
        if self.args.branches == 1:
            x1 = main_batch["x1"][0]
            inputs = torch.cat([x0, x1], dim=0).to(self.device)
        else:
            x1_1 = main_batch["x1_1"][0]
            x1_2 = main_batch["x1_2"][0]

            inputs = torch.cat([x0, x1_1, x1_2], dim=0).to(self.device)

        h = self.forward(inputs)

        loss = ((1 - h) ** 2).mean()
        self.log(
            "MetricModel/val_loss_learn_metric",
            loss,
            on_step=True,
            on_epoch=True,
            prog_bar=True,
        )
        self.last_val_loss = loss.detach()
        return loss

    def calculate_centroids(self, all_data, labels):
        unique_labels = np.unique(labels)
        centroids = np.zeros((len(unique_labels), all_data.shape[1]))
        for i, label in enumerate(unique_labels):
            centroids[i] = all_data[labels == label].mean(axis=0)
        return centroids

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        return optimizer

    def compute_metric(self, x, alpha=1, epsilon=1e-2, image_hx=False):
        if epsilon < 0:
            epsilon = (1 - self.last_val_loss.item()) / abs(epsilon)
        h_x = self.forward(x)
        if image_hx:
            h_x = 1 - torch.abs(1 - h_x)
            M_x = 1 / (h_x**alpha + epsilon)
        else:
            M_x = 1 / (h_x + epsilon) ** alpha
        return M_x
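
The h(x) learned above is a plain RBF expansion, and the induced state cost is its regularized inverse. Below is a minimal standalone sketch of the same computation, with hypothetical random centers, bandwidths, and weights standing in for the fitted KMeans output and the learned W:

import torch

K, D = 100, 3
C = torch.randn(K, D)           # hypothetical cluster centers
lamda = torch.rand(K, 1) + 0.1  # hypothetical per-center precisions
W = torch.rand(K, 1)            # hypothetical positive weights

def h(x):
    # h(x) = sum_k W_k * exp(-0.5 * lamda_k * ||x - C_k||^2)
    dist2 = torch.cdist(x, C) ** 2                                 # (N, K)
    phi = torch.exp(-0.5 * lamda[None, :, :] * dist2[:, :, None])  # (N, K, 1)
    return (W * phi).sum(dim=1)                                    # (N, 1)

x = torch.randn(8, D)
M_x = 1 / (h(x) + 1e-2) ** 1.0  # metric grows as h(x) -> 0, i.e. off the data
print(M_x.shape)                # torch.Size([8, 1])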
train/.DS_Store
ADDED
Binary file (6.15 kB)
train/main_branches.py
ADDED
@@ -0,0 +1,342 @@
import sys
sys.path.append("./BranchSBM")
import os
import argparse
import copy
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import WandbLogger
import wandb
import hydra
from omegaconf import DictConfig, OmegaConf
from torchcfm.optimal_transport import OTPlanSampler

from branchsbm.branchsbm import BranchSBM
from branchsbm.branch_flow_net_train import FlowNetTrainCell, FlowNetTrainLidar
from branchsbm.branch_interpolant_train import BranchInterpolantTrain

from dataloaders.trajectory_data import TemporalDataModule
from dataloaders.mouse_data import WeightedBranchedCellDataModule
from dataloaders.three_branch_data import ThreeBranchTahoeDataModule
from dataloaders.clonidine_v2_data import ClonidineV2DataModule
from dataloaders.clonidine_single_branch import ClonidineSingleBranchDataModule
from dataloaders.trametinib_single import TrametinibSingleBranchDataModule
from dataloaders.lidar_data import WeightedBranchedLidarDataModule
from dataloaders.lidar_data_single import LidarSingleDataModule

from networks.flow_networks.mlp import VelocityNet
from networks.growth_networks.mlp import GrowthNet
from networks.geopath_networks.mlp import GeoPathMLP
from networks.unet_base import UNetModelWrapper as UNetModel
from networks.geopath_networks.unet import GeoPathUNet
from utils import set_seed

from train.parsers import parse_args
from flow_matchers.ema import EMA
from train.train_utils import (
    load_config,
    merge_config,
    generate_group_string,
    dataset_name2datapath,
    create_callbacks,
)
from geo_metrics.metric_factory import DataManifoldMetric
import torch.nn as nn
from flow_matchers.branch_growth_net_train import GrowthNetTrain, GrowthNetTrainCell, GrowthNetTrainLidar


def main(args: argparse.Namespace, seed: int, t_exclude: int) -> None:
    set_seed(seed)
    branches = args.branches

    skipped_time_points = [t_exclude] if t_exclude else []

    ### DATAMODULES ###
    if args.data_name == "lidar":
        datamodule = WeightedBranchedLidarDataModule(args=args)
    elif args.data_name == "lidarsingle":
        datamodule = LidarSingleDataModule(args=args)
    elif args.data_name == "mouse":
        datamodule = WeightedBranchedCellDataModule(args=args)
    elif args.data_name in ["clonidine50D", "clonidine100D", "clonidine150D"]:
        datamodule = ClonidineV2DataModule(args=args)
    elif args.data_name == "clonidine50Dsingle":
        datamodule = ClonidineSingleBranchDataModule(args=args)
    elif args.data_name == "trametinib":
        datamodule = ThreeBranchTahoeDataModule(args=args)
    elif args.data_name == "trametinibsingle":
        datamodule = TrametinibSingleBranchDataModule(args=args)
    else:
        raise ValueError(f"Unknown data_name: {args.data_name}")

    flow_nets = nn.ModuleList()
    geopath_nets = nn.ModuleList()
    growth_nets = nn.ModuleList()

    ##### initialize branched flow and growth networks #####
    for i in range(branches):
        flow_net = VelocityNet(
            dim=args.dim,
            hidden_dims=args.hidden_dims_flow,
            activation=args.activation_flow,
            batch_norm=False,
        )
        geopath_net = GeoPathMLP(
            input_dim=args.dim,
            hidden_dims=args.hidden_dims_geopath,
            time_geopath=args.time_geopath,
            activation=args.activation_geopath,
            batch_norm=False,
        )

        # The first branch's growth net allows negative outputs; the others do not.
        growth_net = GrowthNet(
            dim=args.dim,
            hidden_dims=args.hidden_dims_growth,
            activation=args.activation_growth,
            batch_norm=False,
            negative=(i == 0),
        )

        if args.ema_decay is not None:
            flow_net = EMA(model=flow_net, decay=args.ema_decay)
            geopath_net = EMA(model=geopath_net, decay=args.ema_decay)
            growth_net = EMA(model=growth_net, decay=args.ema_decay)

        flow_nets.append(flow_net)
        geopath_nets.append(geopath_net)
        growth_nets.append(growth_net)

    ot_sampler = (
        OTPlanSampler(method=args.optimal_transport_method)
        if args.optimal_transport_method != "None"
        else None
    )

    wandb.init(
        project=f"branchsbm-{args.data_name}-{branches}-branches",
        group=args.group_name,
        config=vars(args),
        dir=args.working_dir,
    )

    flow_matcher_base = BranchSBM(
        geopath_nets=geopath_nets,
        sigma=args.sigma,
        alpha=int(args.branchsbm),
    )

    ##### STAGE 1: Training of Geodesic Interpolants Beginning #####

    geopath_callbacks = create_callbacks(
        args, phase="geopath", data_type=args.data_type, run_id=wandb.run.id
    )

    # define state cost
    data_manifold_metric = DataManifoldMetric(
        args=args,
        skipped_time_points=skipped_time_points,
        datamodule=datamodule,
    )
    geopath_model = BranchInterpolantTrain(
        flow_matcher=flow_matcher_base,
        skipped_time_points=skipped_time_points,
        ot_sampler=ot_sampler,
        args=args,
        data_manifold_metric=data_manifold_metric,
    )

    wandb_logger = WandbLogger()

    trainer = Trainer(
        max_epochs=args.epochs,
        callbacks=geopath_callbacks,
        accelerator=args.accelerator,
        logger=wandb_logger,
        num_sanity_val_steps=0,
        default_root_dir=args.working_dir,
        gradient_clip_val=(1.0 if args.data_type == "image" else None),
    )

    if args.load_geopath_model_ckpt:
        best_model_path = args.load_geopath_model_ckpt
    else:
        trainer.fit(
            geopath_model,
            datamodule=datamodule,
        )

        best_model_path = geopath_callbacks[0].best_model_path

    geopath_model = BranchInterpolantTrain.load_from_checkpoint(best_model_path)

    flow_matcher_base.geopath_nets = geopath_model.geopath_nets

    ##### STAGE 1: Training of Geodesic Interpolants End #####


    ##### STAGE 2: Flow Matching Beginning #####
    flow_callbacks = create_callbacks(
        args,
        phase="flow",
        data_type=args.data_type,
        run_id=wandb.run.id,
        datamodule=datamodule,
    )
    if args.data_type == "lidar":
        FlowNetTrain = FlowNetTrainLidar
    else:
        FlowNetTrain = FlowNetTrainCell

    flow_train = FlowNetTrain(
        flow_matcher=flow_matcher_base,
        flow_nets=flow_nets,
        ot_sampler=ot_sampler,
        skipped_time_points=skipped_time_points,
        args=args,
    )

    wandb_logger = WandbLogger()

    trainer = Trainer(
        max_epochs=args.epochs,
        callbacks=flow_callbacks,
        check_val_every_n_epoch=args.check_val_every_n_epoch,
        accelerator=args.accelerator,
        logger=wandb_logger,
        default_root_dir=args.working_dir,
        gradient_clip_val=(1.0 if args.data_type == "image" else None),
        num_sanity_val_steps=(0 if args.data_type == "image" else None),
    )

    trainer.fit(
        flow_train, datamodule=datamodule, ckpt_path=args.resume_flow_model_ckpt
    )
    if args.data_type == "lidar":
        trainer.test(flow_train, datamodule=datamodule)
    ##### STAGE 2: Flow Matching End #####


    ##### STAGE 3: Training Growth Networks Beginning ####
    flow_nets = flow_train.flow_nets

    growth_callbacks = create_callbacks(
        args,
        phase="growth",
        data_type=args.data_type,
        run_id=wandb.run.id,
        datamodule=datamodule,
    )

    if args.data_type == "lidar":
        GrowthNetTrain = GrowthNetTrainLidar
    else:
        GrowthNetTrain = GrowthNetTrainCell

    growth_train = GrowthNetTrain(
        flow_nets=flow_nets,
        growth_nets=growth_nets,
        ot_sampler=ot_sampler,
        skipped_time_points=skipped_time_points,
        args=args,
        data_manifold_metric=data_manifold_metric,
        joint=False,
    )

    wandb_logger = WandbLogger()

    trainer = Trainer(
        max_epochs=args.epochs,
        callbacks=growth_callbacks,
        check_val_every_n_epoch=args.check_val_every_n_epoch,
        accelerator=args.accelerator,
        logger=wandb_logger,
        default_root_dir=args.working_dir,
        gradient_clip_val=(1.0 if args.data_type == "image" else None),
        num_sanity_val_steps=(0 if args.data_type == "image" else None),
    )

    trainer.fit(
        growth_train, datamodule=datamodule, ckpt_path=None
    )
    trainer.test(growth_train, datamodule=datamodule)

    ##### STAGE 3: Training Growth Networks End ####

    ##### STAGE 4: Joint Training Beginning ####

    growth_nets = growth_train.growth_nets

    joint_callbacks = create_callbacks(
        args,
        phase="joint",
        data_type=args.data_type,
        run_id=wandb.run.id,
        datamodule=datamodule,
    )

    if args.data_type == "lidar":
        GrowthNetTrain = GrowthNetTrainLidar
    else:
        GrowthNetTrain = GrowthNetTrainCell

    joint_train = GrowthNetTrain(
        flow_nets=flow_nets,
        growth_nets=growth_nets,
        ot_sampler=ot_sampler,
        skipped_time_points=skipped_time_points,
        args=args,
        data_manifold_metric=data_manifold_metric,
        joint=True,
    )

    wandb_logger = WandbLogger()

    trainer = Trainer(
        max_epochs=args.epochs,
        callbacks=joint_callbacks,
        check_val_every_n_epoch=args.check_val_every_n_epoch,
        accelerator=args.accelerator,
        logger=wandb_logger,
        default_root_dir=args.working_dir,
        gradient_clip_val=(1.0 if args.data_type == "image" else None),
        num_sanity_val_steps=(0 if args.data_type == "image" else None),
    )

    trainer.fit(
        joint_train, datamodule=datamodule, ckpt_path=None
    )
    trainer.test(joint_train, datamodule=datamodule)

    ##### STAGE 4: Joint Training End ####

    wandb.finish()

if __name__ == "__main__":
    args = parse_args()
    updated_args = copy.deepcopy(args)
    if args.config_path:
        config = load_config(args.config_path)
        updated_args = merge_config(updated_args, config)

    updated_args.group_name = generate_group_string()
    updated_args.data_path = dataset_name2datapath(
        updated_args.data_name, updated_args.working_dir
    )
    for seed in updated_args.seeds:
        if updated_args.t_exclude:
            for i, t_exclude in enumerate(updated_args.t_exclude):
                updated_args.t_exclude_current = t_exclude
                updated_args.seed_current = seed
                updated_args.gamma_current = updated_args.gammas[i]
                main(updated_args, seed=seed, t_exclude=t_exclude)
        else:
            updated_args.seed_current = seed
            updated_args.gamma_current = updated_args.gammas[0]
            main(updated_args, seed=seed, t_exclude=None)
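
For reference, a hypothetical single-seed driver that mirrors the __main__ block above; it assumes the repo root is importable as a package and reuses the helpers uploaded in this commit:

import copy
from train.parsers import parse_args
from train.train_utils import (
    load_config,
    merge_config,
    generate_group_string,
    dataset_name2datapath,
)
from train.main_branches import main

args = parse_args()
args = merge_config(copy.deepcopy(args), load_config("./configs/experiment/lidar.yaml"))
args.group_name = generate_group_string()
args.data_path = dataset_name2datapath(args.data_name, args.working_dir)
args.seed_current = 42
args.gamma_current = args.gammas[0]
main(args, seed=42, t_exclude=None)  # runs all four training stages once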
train/parsers.py
ADDED
@@ -0,0 +1,419 @@
import argparse

def parse_args():
    parser = argparse.ArgumentParser(description="Train BranchSBM")

    parser.add_argument(
        "--config_path", type=str,
        default='./configs/experiment/lidar.yaml',
        help="Path to config file"
    )

    ####### ITERATES IN THE CODE #######
    parser.add_argument(
        "--seeds",
        nargs="+",
        type=int,
        default=[42, 43, 44, 45, 46],
        help="Random seeds to iterate over",
    )
    parser.add_argument(
        "--t_exclude",
        nargs="+",
        type=int,
        default=[1, 2],
        help="Time points to exclude (iterating over)",
    )
    ####################################

    parser.add_argument(
        "--working_dir",
        type=str,
        default="./",
        help="Working directory",
    )
    parser.add_argument(
        "--resume_flow_model_ckpt",
        type=str,
        default=None,
        help="Path to the flow model to resume training",
    )
    parser.add_argument(
        "--resume_growth_model_ckpt",
        type=str,
        default=None,
        help="Path to the growth model to resume training",
    )
    parser.add_argument(
        "--load_geopath_model_ckpt",
        type=str,
        default=None,
        help="Path to the geopath model to resume training",
    )
    parser.add_argument(
        "--branches",
        type=int,
        default=2,
        help="Number of branches",
    )
    parser.add_argument(
        "--metric_clusters",
        type=int,
        default=3,
        help="Number of metric clusters",
    )

    ######### DATASETS #################
    parser = datasets_parser(parser)
    ####################################

    ######### IMAGE DATASETS ###########
    parser = image_datasets_parser(parser)
    ####################################

    ######### METRICS ##################
    parser = metric_parser(parser)
    ####################################

    ######### General Training #########
    parser = general_training_parser(parser)
    ####################################

    ######### Training GeoPath Network ####
    parser = geopath_network_parser(parser)
    ####################################

    ######### Training Flow Network ####
    parser = flow_network_parser(parser)
    ####################################

    parser = growth_network_parser(parser)

    return parser.parse_args()


def datasets_parser(parser):
    parser.add_argument("--dim", type=int, default=3, help="Dimension of data")

    parser.add_argument(
        "--data_type",
        type=str,
        default="lidar",
        help="Type of data: either scrna or one of the toy datasets",
    )
    parser.add_argument(
        "--data_path",
        type=str,
        default="./data/rainier2-thin.las",
        help="Path to the data file",
    )
    parser.add_argument(
        "--data_name",
        type=str,
        default="lidar",
        help="Name of the dataset",
    )
    parser.add_argument(
        "--whiten",
        action=argparse.BooleanOptionalAction,
        default=True,
        help="Whiten the data",
    )
    return parser


def image_datasets_parser(parser):
    parser.add_argument(
        "--image_size",
        type=int,
        default=128,
        help="Size of the image",
    )
    parser.add_argument(
        "--x0_label",
        type=str,
        default="dog",
        help="Label for x0",
    )
    parser.add_argument(
        "--x1_label",
        type=str,
        default="cat",
        help="Label for x1",
    )
    return parser


def metric_parser(parser):
    parser.add_argument(
        "--branchsbm",
        action=argparse.BooleanOptionalAction,
        default=True,
        help="If branched SBM",
    )
    parser.add_argument(
        "--n_centers",
        type=int,
        default=100,
        help="Number of centers for RBF network",
    )
    parser.add_argument(
        "--kappa",
        type=float,
        default=1.0,
        help="Kappa parameter for RBF network",
    )
    parser.add_argument(
        "--rho",
        type=float,
        default=0.001,
        help="Rho parameter in Riemannian velocity calculation",
    )
    parser.add_argument(
        "--velocity_metric",
        type=str,
        default="rbf",
        help="Metric for velocity calculation",
    )
    parser.add_argument(
        "--gammas",
        nargs="+",
        type=float,
        default=[0.2, 0.2],
        help="Gamma parameter in Riemannian velocity calculation",
    )
    parser.add_argument(
        "--metric_epochs",
        type=int,
        default=50,
        help="Number of epochs for metric learning",
    )
    parser.add_argument(
        "--metric_patience",
        type=int,
        default=5,
        help="Patience for metric learning",
    )
    parser.add_argument(
        "--metric_lr",
        type=float,
        default=1e-2,
        help="Learning rate for metric learning",
    )
    parser.add_argument(
        "--alpha_metric",
        type=float,
        default=1.0,
        help="Alpha parameter for metric learning",
    )

    return parser


def general_training_parser(parser):
    parser.add_argument(
        "--batch_size", type=int, default=128, help="Batch size for training"
    )
    parser.add_argument(
        "--optimal_transport_method",
        type=str,
        default="exact",
        help="Use optimal transport in CFM training",
    )
    parser.add_argument(
        "--ema_decay",
        type=float,
        default=None,
        help="Decay for EMA",
    )
    parser.add_argument(
        "--split_ratios",
        nargs=2,
        type=float,
        default=[0.9, 0.1],
        help="Split ratios for training/validation data in CFM training",
    )
    parser.add_argument("--epochs", type=int, default=100, help="Number of epochs")
    parser.add_argument(
        "--accelerator", type=str, default="cpu", help="Training accelerator"
    )
    parser.add_argument(
        "--sim_num_steps",
        type=int,
        default=1000,
        help="Number of steps in simulation",
    )
    return parser


def geopath_network_parser(parser):
    parser.add_argument(
        "--manifold",
        action=argparse.BooleanOptionalAction,
        default=True,
        help="If use data manifold metric",
    )
    parser.add_argument(
        "--patience_geopath",
        type=int,
        default=50,
        help="Patience for training geopath model",
    )
    parser.add_argument(
        "--hidden_dims_geopath",
        nargs="+",
        type=int,
        default=[64, 64, 64],
        help="Dimensions of hidden layers for GeoPath model training",
    )
    parser.add_argument(
        "--time_geopath",
        action=argparse.BooleanOptionalAction,
        default=False,
        help="Use time in GeoPath model",
    )
    parser.add_argument(
        "--activation_geopath",
        type=str,
        default="selu",
        help="Activation function for GeoPath",
    )
    parser.add_argument(
        "--geopath_optimizer",
        type=str,
        default="adam",
        help="Optimizer for GeoPath training",
    )
    parser.add_argument(
        "--geopath_lr",
        type=float,
        default=1e-4,
        help="Learning rate for GeoPath training",
    )
    parser.add_argument(
        "--geopath_weight_decay",
        type=float,
        default=1e-5,
        help="Weight decay for GeoPath training",
    )
    return parser


def flow_network_parser(parser):
    parser.add_argument(
        "--sigma", type=float, default=0.1, help="Sigma parameter for CFM (variance)"
    )
    parser.add_argument(
        "--patience",
        type=int,
        default=5,
        help="Patience for early stopping in CFM training",
    )
    parser.add_argument(
        "--hidden_dims_flow",
        nargs="+",
        type=int,
        default=[64, 64, 64],
        help="Dimensions of hidden layers for CFM training",
    )
    parser.add_argument(
        "--check_val_every_n_epoch",
        type=int,
        default=10,
        help="Check validation every N epochs during CFM training",
    )
    parser.add_argument(
        "--activation_flow",
        type=str,
        default="selu",
        help="Activation function for CFM",
    )
    parser.add_argument(
        "--flow_optimizer",
        type=str,
        default="adamw",
        help="Optimizer for flow network training",
    )
    parser.add_argument(
        "--flow_lr",
        type=float,
        default=1e-3,
        help="Learning rate for flow network training",
    )
    parser.add_argument(
        "--flow_weight_decay",
        type=float,
        default=1e-5,
        help="Weight decay for flow network training",
    )
    return parser

def growth_network_parser(parser):
    parser.add_argument(
        "--patience_growth",
        type=int,
        default=5,
        help="Patience for early stopping in growth net training",
    )
    parser.add_argument(
        "--time_growth",
        action=argparse.BooleanOptionalAction,
        default=False,
        help="Use time in growth model",
    )
    parser.add_argument(
        "--hidden_dims_growth",
        nargs="+",
        type=int,
        default=[64, 64, 64],
        help="Dimensions of hidden layers for growth net training",
    )
    parser.add_argument(
        "--activation_growth",
        type=str,
        default="tanh",
        help="Activation function for growth network",
    )
    parser.add_argument(
        "--growth_optimizer",
        type=str,
        default="adamw",
        help="Optimizer for growth network training",
    )
    parser.add_argument(
        "--growth_lr",
        type=float,
        default=1e-3,
        help="Learning rate for growth network training",
    )
    parser.add_argument(
        "--growth_weight_decay",
        type=float,
        default=1e-5,
        help="Weight decay for growth network training",
    )
    parser.add_argument(
        "--lambda_energy",
        type=float,
        default=1.0,
        help="Weight for energy loss",
    )
    parser.add_argument(
        "--lambda_mass",
        type=float,
        default=100.0,
        help="Weight for mass loss",
    )
    parser.add_argument(
        "--lambda_match",
        type=float,
        default=1000.0,
        help="Weight for matching loss",
    )
    parser.add_argument(
        "--lambda_recons",
        type=float,
        default=1.0,
        help="Weight for reconstruction loss",
    )
    return parser
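
All boolean switches in these parsers use argparse.BooleanOptionalAction (Python 3.9+), which auto-generates a paired --no- form. A quick self-contained demonstration of the convention:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--whiten", action=argparse.BooleanOptionalAction, default=True)

print(parser.parse_args([]).whiten)               # True  (default)
print(parser.parse_args(["--no-whiten"]).whiten)  # False
print(parser.parse_args(["--whiten"]).whiten)     # True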
train/train_utils.py
ADDED
@@ -0,0 +1,154 @@
import sys
sys.path.append("./BranchSBM")
import yaml
import string
import secrets
import os
import torch
import wandb
from pytorch_lightning.callbacks import Callback, EarlyStopping, ModelCheckpoint
from torchdyn.core import NeuralODE
from utils import plot_images_trajectory
from networks.utils import flow_model_torch_wrapper


def load_config(path):
    with open(path, "r") as file:
        config = yaml.safe_load(file)
    return config


def merge_config(args, config_updates):
    for key, value in config_updates.items():
        if not hasattr(args, key):
            raise ValueError(
                f"Unknown configuration parameter '{key}' found in the config file."
            )
        setattr(args, key, value)
    return args


def generate_group_string(length=16):
    alphabet = string.ascii_letters + string.digits
    return "".join(secrets.choice(alphabet) for _ in range(length))


def dataset_name2datapath(dataset_name, working_dir):
    if dataset_name in ["lidar", "lidarsingle"]:
        return os.path.join(working_dir, "/raid/st512/branchsbm/data", "rainier2-thin.las")
    elif dataset_name == "mouse":
        return os.path.join(working_dir, "/raid/st512/branchsbm/data", "mouse_hematopoiesis.csv")
    elif dataset_name in ["clonidine50D", "clonidine100D", "clonidine150D", "clonidine50Dsingle", "clonidine100Dsingle", "clonidine150Dsingle"]:
        return os.path.join(working_dir, "/raid/st512/branchsbm/data", "pca_and_leiden_labels.csv")
    elif dataset_name in ["trametinib", "trametinibsingle"]:
        return os.path.join(working_dir, "/raid/st512/branchsbm/data", "Trametinib_5.0uM_pca_and_leidenumap_labels.csv")
    else:
        raise ValueError("Dataset not recognized")


def create_callbacks(args, phase, data_type, run_id, datamodule=None):

    dirpath = os.path.join(
        args.working_dir,
        "checkpoints",
        data_type,
        str(run_id),
        f"{phase}_model",
    )

    if phase == "geopath":
        early_stop_callback = EarlyStopping(
            monitor="BranchPathNet/val_loss_geopath",
            patience=args.patience_geopath,
            mode="min",
        )
        checkpoint_callback = ModelCheckpoint(
            dirpath=dirpath,
            monitor="BranchPathNet/val_loss_geopath",
            mode="min",
            save_top_k=1,
        )
        callbacks = [checkpoint_callback, early_stop_callback]
    elif phase == "flow":
        early_stop_callback = EarlyStopping(
            monitor="FlowNet/val_loss_cfm",
            patience=args.patience,
            mode="min",
        )
        checkpoint_callback = ModelCheckpoint(
            dirpath=dirpath,
            mode="min",
            save_top_k=1,
        )
        callbacks = [checkpoint_callback, early_stop_callback]
    elif phase == "growth":
        early_stop_callback = EarlyStopping(
            monitor="GrowthNet/val_loss",
            patience=args.patience,
            mode="min",
        )
        checkpoint_callback = ModelCheckpoint(
            dirpath=dirpath,
            mode="min",
            save_top_k=1,
        )
        callbacks = [checkpoint_callback, early_stop_callback]
    elif phase == "joint":
        early_stop_callback = EarlyStopping(
            monitor="JointTrain/val_loss",
            patience=args.patience,
            mode="min",
        )
        checkpoint_callback = ModelCheckpoint(
            dirpath=dirpath,
            mode="min",
            save_top_k=1,
        )
        callbacks = [checkpoint_callback, early_stop_callback]
    else:
        raise ValueError("Unknown phase")
    return callbacks


class PlottingCallback(Callback):
    def __init__(self, plot_interval, datamodule):
        self.plot_interval = plot_interval
        self.datamodule = datamodule

    def on_train_epoch_end(self, trainer, pl_module):
        epoch = trainer.current_epoch
        pl_module.flow_net.train(mode=False)
        if epoch % self.plot_interval == 0 and epoch != 0:
            node = NeuralODE(
                flow_model_torch_wrapper(pl_module.flow_net).to(self.datamodule.device),
                solver="tsit5",
                sensitivity="adjoint",
                atol=1e-5,
                rtol=1e-5,
            )

            for mode in ["train", "val"]:
                x0 = getattr(self.datamodule, f"{mode}_x0")
                x0 = x0[0:15]
                fig = self.trajectory_and_plot(x0, node, self.datamodule)
                wandb.log({f"Trajectories {mode.capitalize()}": wandb.Image(fig)})
        pl_module.flow_net.train(mode=True)

    def trajectory_and_plot(self, x0, node, datamodule):
        selected_images = x0[0:15]
        with torch.no_grad():
            traj = node.trajectory(
                selected_images.to(datamodule.device),
                t_span=torch.linspace(0, 1, 100).to(datamodule.device),
            )

        traj = traj.transpose(0, 1)
        traj = traj.reshape(*traj.shape[0:2], *datamodule.dim)

        fig = plot_images_trajectory(
            traj.to(datamodule.device),
            datamodule.vae.to(datamodule.device),
            datamodule.process,
            num_steps=5,
        )
        return fig
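
merge_config deliberately fails fast on keys the parser does not know, which catches typos in experiment YAMLs before a run starts. A small sketch of that behavior, using the merge_config defined above with argparse.Namespace standing in for the object returned by parse_args():

from argparse import Namespace

args = Namespace(epochs=100, sigma=0.1)
args = merge_config(args, {"epochs": 250})  # known key: overrides the default
print(args.epochs)  # 250

try:
    merge_config(args, {"sigmaa": 0.2})  # misspelled key: raises
except ValueError as err:
    print(err)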
utils.py
ADDED
@@ -0,0 +1,198 @@
import numpy as np
import torch
import random
import matplotlib
import matplotlib.pyplot as plt
import math
import umap
import scanpy as sc
from sklearn.decomposition import PCA

import ot as pot
from tqdm import tqdm
from functools import partial
from typing import Optional

from matplotlib.colors import LinearSegmentedColormap


def set_seed(seed):
    """
    Sets the seed for reproducibility in PyTorch, Numpy, and Python's Random.

    Parameters:
        seed (int): The seed for the random number generators.
    """
    random.seed(seed)  # Python random module
    np.random.seed(seed)  # Numpy
    torch.manual_seed(seed)  # CPU and GPU (deterministic)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)  # CUDA
        torch.cuda.manual_seed_all(seed)  # all GPU devices
        torch.backends.cudnn.deterministic = True  # CuDNN behavior
        torch.backends.cudnn.benchmark = False


def wasserstein_distance(
    x0: torch.Tensor,
    x1: torch.Tensor,
    method: Optional[str] = None,
    reg: float = 0.05,
    power: int = 1,
    **kwargs,
) -> float:
    assert power == 1 or power == 2
    if method == "exact" or method is None:
        ot_fn = pot.emd2
    elif method == "sinkhorn":
        ot_fn = partial(pot.sinkhorn2, reg=reg)
    else:
        raise ValueError(f"Unknown method: {method}")

    a, b = pot.unif(x0.shape[0]), pot.unif(x1.shape[0])
    if x0.dim() > 2:
        x0 = x0.reshape(x0.shape[0], -1)
    if x1.dim() > 2:
        x1 = x1.reshape(x1.shape[0], -1)
    M = torch.cdist(x0, x1)
    if power == 2:
        M = M**2
    ret = ot_fn(a, b, M.detach().cpu().numpy(), numItermax=int(1e7))
    if power == 2:
        ret = math.sqrt(ret)
    return ret


def plot_lidar(ax, dataset, xs=None, S=25, branch_idx=None):
    # Combine the dataset and trajectory points for sorting
    combined_points = []
    combined_colors = []
    combined_sizes = []

    custom_colors_1 = ["#05009E", "#A19EFF", "#50B2D7"]
    custom_colors_2 = ["#05009E", "#A19EFF", "#D577FF"]

    custom_cmap_1 = LinearSegmentedColormap.from_list("my_cmap", custom_colors_1)
    custom_cmap_2 = LinearSegmentedColormap.from_list("my_cmap", custom_colors_2)

    # Normalize the z-coordinates for shading
    z_coords = (
        dataset[:, 2].numpy() if torch.is_tensor(dataset[:, 2]) else dataset[:, 2]
    )
    z_min, z_max = z_coords.min(), z_coords.max()
    z_norm = (z_coords - z_min) / (z_max - z_min)

    # Add surface points with a lower z-order
    for i, point in enumerate(dataset):
        grey_value = 0.95 - 0.7 * z_norm[i]
        combined_points.append(point.numpy())
        combined_colors.append(
            (grey_value, grey_value, grey_value, 1.0)
        )  # opaque grey, darker with height
        combined_sizes.append(0.1)

    # Add trajectory points with a higher z-order
    if xs is not None:
        if branch_idx == 0:
            cmap = custom_cmap_1
        else:
            cmap = custom_cmap_2

        B, T, D = xs.shape
        steps_to_log = np.linspace(0, T - 1, S).astype(int)
        xs = xs.cpu().detach().clone()
        for idx, step in enumerate(steps_to_log):
            for point in xs[:512, step]:
                combined_points.append(
                    point.numpy() if torch.is_tensor(point) else point
                )
                combined_colors.append(cmap(idx / (len(steps_to_log) - 1)))
                combined_sizes.append(0.8)

    # Convert to numpy array for easier manipulation
    combined_points = np.array(combined_points)
    combined_colors = np.array(combined_colors)
    combined_sizes = np.array(combined_sizes)

    # Sort by z-coordinate (depth)
    sorted_indices = np.argsort(combined_points[:, 2])
    combined_points = combined_points[sorted_indices]
    combined_colors = combined_colors[sorted_indices]
    combined_sizes = combined_sizes[sorted_indices]

    # Plot the sorted points
    ax.scatter(
        combined_points[:, 0],
        combined_points[:, 1],
        combined_points[:, 2],
        s=combined_sizes,
        c=combined_colors,
        depthshade=True,
    )

    ax.set_xlim3d(left=-4.8, right=4.8)
    ax.set_ylim3d(bottom=-4.8, top=4.8)
    ax.set_zlim3d(bottom=0.0, top=2.0)
    ax.set_zticks([0, 1.0, 2.0])
    ax.grid(False)
    plt.axis("off")

    return ax


def plot_images_trajectory(trajectories, vae, processor, num_steps):

    # Select evenly spaced steps along each trajectory
    t_span = torch.linspace(0, trajectories.shape[1] - 1, num_steps)
    t_span = [int(t) for t in t_span]
    num_images = trajectories.shape[0]

    # Decode images at each step in each trajectory
    decoded_images = [
        [
            processor.postprocess(
                vae.decode(
                    trajectories[i_image, traj_step].unsqueeze(0)
                ).sample.detach()
            )[0]
            for traj_step in t_span
        ]
        for i_image in range(num_images)
    ]

    # Plotting
    fig, axes = plt.subplots(
        num_images, num_steps, figsize=(num_steps * 2, num_images * 2)
    )
    if num_images == 1:
        axes = [axes]  # Ensure axes is iterable
    for img_idx, img_traj in enumerate(decoded_images):
        for step_idx, img in enumerate(img_traj):
            ax = axes[img_idx][step_idx] if num_images > 1 else axes[step_idx]
            if (
                isinstance(img, np.ndarray) and img.shape[0] == 3
            ):  # Assuming 3 channels (RGB)
                img = img.transpose(1, 2, 0)
            ax.imshow(img)
            ax.axis("off")
            if img_idx == 0:
                ax.set_title(f"t={t_span[step_idx]/t_span[-1]:.2f}")
    plt.tight_layout()
    return fig


def plot_growth(dataset, growth_nets, xs, output_file='plot.pdf'):
    # Unpack endpoints and weights for each branch
    x0s = [dataset["x0"][0]]
    w0s = [dataset["x0"][1]]
    x1s_list = [[dataset["x1_1"][0]], [dataset["x1_2"][0]]]
    w1s_list = [[dataset["x1_1"][1]], [dataset["x1_2"][1]]]

    plt.show()
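
A quick sanity check of wasserstein_distance on two shifted Gaussian clouds; this assumes the POT package imported above as pot is installed, and the printed values will vary with the random draw:

import torch

x0 = torch.randn(256, 3)
x1 = torch.randn(256, 3) + 1.0  # shifted cloud

w1 = wasserstein_distance(x0, x1, method="exact", power=1)
w2 = wasserstein_distance(x0, x1, method="sinkhorn", reg=0.05, power=2)
print(f"W1 = {w1:.3f}, entropic W2 = {w2:.3f}")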