darabos committed on
Commit
e7c9562
·
2 Parent(s): 663af8e cef1508

Merge remote-tracking branch 'origin/main' into darabos-open-source-merge

Browse files
lynxkite-graph-analytics/.dockerignore ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ lynxkite_data
2
+ lynxkite_crdt_data
3
+ .venv
lynxkite-graph-analytics/Dockerfile.bionemo ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
FROM nvcr.io/nvidia/clara/bionemo-framework:nightly

# NOTE(review): presumably read by LynxKite to enable the BioNeMo operations -- confirm.
ENV LYNXKITE_BIONEMO_INSTALLED=true

WORKDIR /app

# Download and install nvm, then use it to install Node.js.
RUN curl -o- https://raw.githubusercontent.com/nvm-sh/nvm/v0.39.2/install.sh | bash
# "node" in .nvmrc makes nvm pick the latest Node.js release.
RUN echo node > .nvmrc
# `source` is a bash builtin, while the shell form of RUN defaults to `/bin/sh -c`.
# Run the step under bash explicitly so it does not depend on the base image's SHELL.
RUN bash -c "source /root/.nvm/nvm.sh --install"

# The build context must be the repository root: the packages installed below
# are sibling directories of lynxkite-graph-analytics.
COPY . /app

RUN uv pip install -e lynxkite-core/[dev] -e lynxkite-app/[dev] -e lynxkite-graph-analytics/[dev] -e lynxkite-bio -e lynxkite-pillow-example/

# bionemo cellxgene_census needs this version of numpy
RUN uv pip install numpy==1.26.4

ENV LYNXKITE_DATA=examples

CMD ["lynxkite"]
lynxkite-graph-analytics/src/lynxkite_graph_analytics/bionemo_ops.py ADDED
@@ -0,0 +1,493 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """BioNeMo related operations
2
+
3
+ The intention is to showcase how BioNeMo can be integrated with LynxKite. This should be
4
+ considered a reference implementation, not production-ready code.
5
+ The operations are quite specific for this example notebook:
6
+ https://github.com/NVIDIA/bionemo-framework/blob/main/docs/docs/user-guide/examples/bionemo-geneformer/geneformer-celltype-classification.ipynb
7
+ """
8
+
9
+ from lynxkite.core import ops
10
+ import requests
11
+ import tarfile
12
+ import os
13
+ from collections import Counter
14
+ from . import core
15
+ import numpy as np
16
+ import torch
17
+ from pathlib import Path
18
+ import random
19
+ from contextlib import contextmanager
20
+ import cellxgene_census # TODO: This needs numpy < 2
21
+ import tempfile
22
+ from sklearn.ensemble import RandomForestClassifier
23
+ from sklearn.pipeline import Pipeline
24
+ from sklearn.model_selection import StratifiedKFold, cross_validate
25
+ from sklearn.metrics import (
26
+ make_scorer,
27
+ accuracy_score,
28
+ precision_score,
29
+ recall_score,
30
+ f1_score,
31
+ roc_auc_score,
32
+ confusion_matrix,
33
+ )
34
+ from sklearn.decomposition import PCA
35
+ from sklearn.model_selection import cross_val_predict
36
+ from sklearn.preprocessing import LabelEncoder
37
+ from bionemo.scdl.io.single_cell_collection import SingleCellCollection
38
+
39
+ import scanpy
40
+
41
+
42
# Decorator used below to register every BioNeMo operation in the `core.ENV` environment.
op = ops.op_registration(core.ENV)
# Base directory the operations in this module read from and write to.
# NOTE(review): presumably the working directory of the BioNeMo container -- confirm.
DATA_PATH = Path("/workspace")
44
+
45
+
46
@contextmanager
def random_seed(seed: int):
    """Temporarily seed Python's global `random` generator.

    The generator state captured on entry is restored on exit, even when the
    body of the `with` block raises, so code outside the block is unaffected.
    """
    saved_state = random.getstate()
    random.seed(seed)
    try:
        yield
    finally:
        # Put the generator back exactly where the caller left it.
        random.setstate(saved_state)
55
+
56
+
57
@op("BioNeMo > Download CELLxGENE dataset", slow=True)
def download_cellxgene_dataset(
    *,
    save_path: str,
    census_version: str = "2023-12-15",
    organism: str = "Homo sapiens",
    value_filter='dataset_id=="8e47ed12-c658-4252-b126-381df8d52a3d"',
    max_workers: int = 1,
    use_mp: bool = False,
) -> Path:
    """Download a CELLxGENE census slice and store it as a single-cell collection.

    Args:
        save_path: Output location for the flattened collection, joined onto `DATA_PATH`.
        census_version: Census release passed to `cellxgene_census.open_soma`.
        organism: Organism passed to `cellxgene_census.get_anndata`.
        value_filter: `obs_value_filter` expression selecting the cells to download.
        max_workers: Passed through to `SingleCellCollection.load_h5ad_multi`.
        use_mp: Passed to `load_h5ad_multi` as `use_processes`.

    Returns:
        `DATA_PATH / save_path`, the location the collection was flattened to.
    """
    micro_batch_size: int = 32
    num_steps: int = 256
    output_path = DATA_PATH / save_path

    with cellxgene_census.open_soma(census_version=census_version) as census:
        adata = cellxgene_census.get_anndata(
            census,
            organism,
            obs_value_filter=value_filter,
        )

    # Shuffle with a fixed seed so the same subsample is selected on every run,
    # without disturbing the global `random` state of the caller.
    with random_seed(32):
        indices = list(range(len(adata)))
        random.shuffle(indices)
    selection = sorted(indices[: micro_batch_size * num_steps])
    # NOTE: there's a current constraint that predict_step needs to be a function of micro-batch-size.
    # this is something we are working on fixing. A quick hack is to set micro-batch-size=1, but this is
    # slow. Here we use mbs=32 and subsample the anndata.
    adata = adata[selection].copy()  # so it's not a view

    h5ad_outfile = DATA_PATH / Path("hs-celltype-bench.h5ad")
    adata.write_h5ad(h5ad_outfile)

    with tempfile.TemporaryDirectory() as temp_dir:
        coll = SingleCellCollection(temp_dir)
        # NOTE(review): this is given the directory containing the file, so presumably
        # every .h5ad file under DATA_PATH is loaded, not just the one written above -- confirm.
        coll.load_h5ad_multi(h5ad_outfile.parent, max_workers=max_workers, use_processes=use_mp)
        coll.flatten(output_path, destroy_on_copy=True)
    return output_path
92
+
93
+
94
@op("BioNeMo > Import H5AD file")
def import_h5ad(*, file_path: str):
    """Read an H5AD file with scanpy; `file_path` is joined onto `DATA_PATH`."""
    h5ad_path = DATA_PATH / Path(file_path)
    return scanpy.read_h5ad(h5ad_path)
97
+
98
+
99
@op("BioNeMo > Download model", slow=True)
def download_model(*, model_name: str) -> str:
    """Download a Geneformer checkpoint archive from NGC and unpack it under `DATA_PATH`.

    Args:
        model_name: One of "geneformer_100m", "geneformer_10m" or "geneformer_10m2".

    Returns:
        The directory (as a string) the archive was extracted into.

    Raises:
        KeyError: If `model_name` is not one of the known models.
        requests.HTTPError: If the download responds with a 4xx or 5xx status.
    """
    model_download_parameters = {
        "geneformer_100m": {
            "name": "geneformer_100m",
            "version": "2.0",
            "path": "geneformer_106M_240530_nemo2",
        },
        "geneformer_10m": {
            "name": "geneformer_10m",
            "version": "2.0",
            "path": "geneformer_10M_240530_nemo2",
        },
        "geneformer_10m2": {
            "name": "geneformer_10m",
            "version": "2.1",
            "path": "geneformer_10M_241113_nemo2",
        },
    }
    parameters = model_download_parameters[model_name]

    # Define the URL and output file
    url_template = "https://api.ngc.nvidia.com/v2/models/org/nvidia/team/clara/{name}/{version}/files?redirect=true&path={path}.tar.gz"
    url = url_template.format(**parameters)
    model_filename = f"{DATA_PATH}/{parameters['path']}"
    output_file = f"{model_filename}.tar.gz"

    # Stream the archive to disk. The `with` block releases the connection even
    # if the status check or a write fails part-way through.
    with requests.get(url, allow_redirects=True, stream=True) as response:
        response.raise_for_status()  # Raise an error for bad responses (4xx and 5xx)
        with open(output_file, "wb") as file:
            for chunk in response.iter_content(chunk_size=8192):
                file.write(chunk)

    # Extract the tar.gz file
    os.makedirs(model_filename, exist_ok=True)
    with tarfile.open(output_file, "r:gz") as tar:
        # NOTE(review): extractall() trusts the member paths inside the archive.
        # Consider an extraction filter if the archive source is not fully trusted.
        tar.extractall(path=model_filename)

    return model_filename
141
+
142
+
143
@op("BioNeMo > Infer", slow=True)
def infer(dataset_path: str, model_path: str | None = None, *, results_path: str) -> Path:
    """Run Geneformer inference on a dataset and write the predictions to disk.

    Args:
        dataset_path: Passed to `infer_model` as `data_path`.
        model_path: Passed to `infer_model` as `checkpoint_path`.
        results_path: Output location, joined onto `DATA_PATH`.

    Returns:
        `DATA_PATH / results_path`, the location handed to `infer_model` for its results.
    """
    # This import is slow, so we only import it when we need it.
    from bionemo.geneformer.scripts.infer_geneformer import infer_model

    output_path = DATA_PATH / results_path
    infer_model(
        data_path=dataset_path,
        checkpoint_path=model_path,
        results_path=output_path,
        include_hiddens=False,
        micro_batch_size=32,
        include_embeddings=True,
        include_logits=False,
        seq_length=2048,
        precision="bf16-mixed",
        devices=1,
        num_nodes=1,
        num_dataset_workers=10,
    )
    return output_path
164
+
165
+
166
@op("BioNeMo > Load results")
def load_results(results_path: str):
    """Load the embeddings saved under `results_path` as a float32 NumPy array.

    Reads `predictions__rank_0.pt` from the given directory and returns its
    "embeddings" entry.
    """
    # Map storages to the CPU while loading: the result is converted to NumPy
    # anyway, and this lets predictions saved on a GPU load on a CPU-only host.
    # NOTE: torch.load unpickles the file, so only use it on trusted prediction files.
    predictions = torch.load(f"{results_path}/predictions__rank_0.pt", map_location="cpu")
    embeddings = predictions["embeddings"].float().cpu().numpy()
    return embeddings
172
+
173
+
174
@op("BioNeMo > Get labels")
def get_labels(adata):
    """Fit a `LabelEncoder` on the "cell_type" column of `adata.obs`.

    The returned encoder also carries the encoded labels in an extra
    `integer_labels` attribute, so both the class names (`classes_`) and the
    per-cell integer labels travel together.
    """
    cell_types = adata.obs["cell_type"].values
    encoder = LabelEncoder()
    encoder.integer_labels = encoder.fit_transform(cell_types)
    return encoder
182
+
183
+
184
@op("BioNeMo > Plot labels", view="visualization")
def plot_labels(adata):
    """Build bar-chart options showing how many cells there are of each cell type.

    Counts the values of the "cell_type" column of `adata.obs` and returns the
    chart options as a plain dictionary.
    """
    cell_type_counts = Counter(adata.obs["cell_type"].values)
    categories = list(cell_type_counts)
    counts = list(cell_type_counts.values())

    title = {
        "text": "Cell type counts for classification dataset",
        "left": "center",
    }
    x_axis = {
        "type": "category",
        "data": categories,
        "axisLabel": {"rotate": 45, "align": "right"},
    }
    bars = {
        "name": "Count",
        "type": "bar",
        "data": counts,
        "itemStyle": {"color": "#4285F4"},
    }
    return {
        "title": title,
        "tooltip": {"trigger": "axis", "axisPointer": {"type": "shadow"}},
        "xAxis": x_axis,
        "yAxis": {"type": "value"},
        "series": [bars],
    }
214
+
215
+
216
@op("BioNeMo > Run benchmark", slow=True)
def run_benchmark(data, labels, *, use_pca: bool = False):
    """Benchmark a random forest classifier on the data with 5-fold cross-validation.

    Args:
        data: One row of features per cell, shape (R, C).
        labels: Object whose `integer_labels` attribute holds the integer class
            of each cell, shape (R,), such as the encoder built by `get_labels`.
        use_pca: If true, project the data to 10 principal components before
            classification.

    Returns:
        A `(metrics, confusion_matrix)` tuple. `metrics` maps "test_<metric>" to a
        `(mean, std)` pair over the folds. The confusion matrix is computed from
        cross-validated predictions.
    """
    # RandomForestClassifier is created without a random_state, so it draws from
    # NumPy's global generator. Seeding it here makes the run reproducible.
    np.random.seed(1337)
    # Define the target dimension 'n_components'
    n_components = 10  # for example, adjust based on your specific needs

    # Create a pipeline with an optional PCA projection and a RandomForestClassifier
    steps = []
    if use_pca:
        steps.append(("projection", PCA(n_components=n_components)))
    steps.append(("classifier", RandomForestClassifier(class_weight="balanced")))
    pipeline = Pipeline(steps)

    # Set up StratifiedKFold to ensure each fold reflects the overall distribution of labels
    cv = StratifiedKFold(n_splits=5)

    # Define the scoring functions
    scoring = {
        "accuracy": make_scorer(accuracy_score),
        "precision": make_scorer(precision_score, average="macro"),  # 'macro' averages over classes
        "recall": make_scorer(recall_score, average="macro"),
        "f1_score": make_scorer(f1_score, average="macro"),
        # ROC AUC must be computed from class probabilities. A plain make_scorer()
        # would pass hard predict() labels to roc_auc_score, which is not valid for
        # the multiclass case. The built-in scorer uses predict_proba instead.
        "roc_auc": "roc_auc_ovr",
    }
    integer_labels = labels.integer_labels
    # Perform stratified cross-validation with multiple metrics using the pipeline
    results = cross_validate(
        pipeline, data, integer_labels, cv=cv, scoring=scoring, return_train_score=False
    )

    # Print the cross-validation results
    print("Cross-validation metrics:")
    results_out = {}
    for metric, scores in results.items():
        if metric.startswith("test_"):
            results_out[metric] = (scores.mean(), scores.std())
            print(f"{metric[5:]}: {scores.mean():.3f} (+/- {scores.std():.3f})")

    predictions = cross_val_predict(pipeline, data, integer_labels, cv=cv)

    # Return confusion matrix and metrics.
    conf_matrix = confusion_matrix(integer_labels, predictions)

    return results_out, conf_matrix
272
+
273
+
274
@op("BioNeMo > Plot confusion matrix", view="visualization", slow=True)
def plot_confusion_matrix(benchmark_output, labels):
    """Build heatmap options for a row-normalized confusion matrix.

    Args:
        benchmark_output: Sequence whose second element is the confusion matrix.
        labels: Object whose `classes_` attribute lists the class names.

    Returns:
        The chart options as a plain dictionary.
    """
    confusion = benchmark_output[1]
    class_names = [str(name) for name in labels.classes_]
    size = len(class_names)

    # Normalize each row by its total; rows that sum to zero are left as zeros.
    normalized = []
    for row in confusion:
        normalized.append([float(cell / sum(row)) if sum(row) else 0 for cell in row])

    # heatmap has the 0,0 at the bottom left corner, so the row index is flipped
    cells = []
    for i in range(size):
        for j in range(size):
            cells.append([j, size - i - 1, normalized[i][j]])

    x_axis = {
        "type": "category",
        "data": class_names,
        "splitArea": {"show": True},
        "axisLabel": {"rotate": 70, "align": "right"},
    }
    y_axis = {
        "type": "category",
        "data": list(reversed(class_names)),
        "splitArea": {"show": True},
    }
    grid = {
        "height": "70%",
        "width": "70%",
        "left": "20%",
        "right": "10%",
        "bottom": "10%",
        "top": "10%",
    }
    visual_map = {
        "min": 0,
        "max": 1,
        "calculable": True,
        "orient": "vertical",
        "right": 10,
        "top": "center",
        "inRange": {"color": ["#E0F7FA", "#81D4FA", "#29B6F6", "#0288D1", "#01579B"]},
    }
    heatmap = {
        "name": "Confusion matrix",
        "type": "heatmap",
        "data": cells,
        "emphasis": {"itemStyle": {"borderColor": "#333", "borderWidth": 1}},
        "itemStyle": {"borderColor": "#D3D3D3", "borderWidth": 2},
    }
    return {
        "title": {"text": "Confusion Matrix", "left": "center"},
        "tooltip": {"position": "top"},
        "xAxis": x_axis,
        "yAxis": y_axis,
        "grid": grid,
        "visualMap": visual_map,
        "series": [heatmap],
    }
328
+
329
+
330
@op("BioNeMo > Plot accuracy comparison", view="visualization")
def accuracy_comparison(benchmark_output10m, benchmark_output100m):
    """Build bar-chart options comparing the accuracy of the two models.

    Each argument is a sequence whose first element maps "test_accuracy" to a
    `(mean, std)` pair. The means become the bars and the standard deviations
    the error bars.
    """
    benchmark_results = (benchmark_output10m[0], benchmark_output100m[0])
    model_names = ["10M parameters", "106M parameters"]  # X-axis labels
    means = [result["test_accuracy"][0] for result in benchmark_results]  # Y-axis values
    stds = [result["test_accuracy"][1] for result in benchmark_results]

    bold_14 = {"fontSize": 14, "fontWeight": "bold"}
    title = {
        "text": "Accuracy Comparison",
        "left": "center",
        "textStyle": {"fontSize": 20, "fontWeight": "bold"},
    }
    grid = {
        "height": "70%",
        "width": "70%",
        "left": "20%",
        "right": "10%",
        "bottom": "10%",
        "top": "10%",
    }
    x_axis = {
        "type": "category",
        "data": model_names,
        # Rotated labels are easier to read.
        "axisLabel": {"rotate": 45, "align": "right", "textStyle": dict(bold_14)},
    }
    y_axis = {
        "type": "value",
        "name": "Accuracy",
        "min": 0,
        "max": 1,
        "interval": 0.1,  # One tick every 0.1
        "axisLabel": {"textStyle": dict(bold_14)},
    }
    bars = {
        "name": "Accuracy",
        "type": "bar",
        "data": means,
        "itemStyle": {"color": "#440154"},  # Viridis color palette (dark purple)
    }
    # One [lower, upper] pair per model: mean minus/plus one standard deviation.
    error_bars = {
        "name": "Error Bars",
        "type": "errorbar",
        "data": [[mean - std, mean + std] for mean, std in zip(means, stds)],
        "itemStyle": {"color": "#1f77b4"},
    }
    return {
        "title": title,
        "grid": grid,
        "tooltip": {"trigger": "axis", "axisPointer": {"type": "shadow"}},
        "xAxis": x_axis,
        "yAxis": y_axis,
        "series": [bars, error_bars],
    }
411
+
412
+
413
@op("BioNeMo > Plot f1 comparison", view="visualization")
def f1_comparison(benchmark_output10m, benchmark_output100m):
    """Build bar-chart options comparing the F1 score of the two models.

    Each argument is a sequence whose first element maps "test_f1_score" to a
    `(mean, std)` pair. The means become the bars and the standard deviations
    the error bars.
    """
    benchmark_results = (benchmark_output10m[0], benchmark_output100m[0])
    model_names = ["10M parameters", "106M parameters"]  # X-axis labels
    means = [result["test_f1_score"][0] for result in benchmark_results]  # Y-axis values
    stds = [result["test_f1_score"][1] for result in benchmark_results]

    bold_14 = {"fontSize": 14, "fontWeight": "bold"}
    title = {
        "text": "F1 Score Comparison",
        "left": "center",
        "textStyle": {"fontSize": 20, "fontWeight": "bold"},
    }
    grid = {
        "height": "70%",
        "width": "70%",
        "left": "20%",
        "right": "10%",
        "bottom": "10%",
        "top": "10%",
    }
    x_axis = {
        "type": "category",
        "data": model_names,
        # Rotated labels are easier to read.
        "axisLabel": {"rotate": 45, "align": "right", "textStyle": dict(bold_14)},
    }
    y_axis = {
        "type": "value",
        "name": "F1 Score",
        "min": 0,
        "max": 1,
        "interval": 0.1,  # One tick every 0.1
        "axisLabel": {"textStyle": dict(bold_14)},
    }
    bars = {
        "name": "F1 Score",
        "type": "bar",
        "data": means,
        "itemStyle": {"color": "#440154"},  # Viridis color palette (dark purple)
    }
    # One [lower, upper] pair per model: mean minus/plus one standard deviation.
    error_bars = {
        "name": "Error Bars",
        "type": "errorbar",
        "data": [[mean - std, mean + std] for mean, std in zip(means, stds)],
        "itemStyle": {"color": "#1f77b4"},
    }
    return {
        "title": title,
        "grid": grid,
        "tooltip": {"trigger": "axis", "axisPointer": {"type": "shadow"}},
        "xAxis": x_axis,
        "yAxis": y_axis,
        "series": [bars, error_bars],
    }