darabos committed on
Commit
e05c3b0
·
1 Parent(s): 6ef6dd5

Make it simpler to declare operations "slow".

Browse files
lynxkite-bio/src/lynxkite_bio/nims.py CHANGED
@@ -2,13 +2,11 @@
2
 
3
  from lynxkite_graph_analytics import Bundle
4
  from lynxkite.core import ops
5
- import joblib
6
  import httpx
7
  import pandas as pd
8
  import os
9
 
10
 
11
- mem = joblib.Memory(".joblib-cache")
12
  ENV = "LynxKite Graph Analytics"
13
  op = ops.op_registration(ENV)
14
 
@@ -35,8 +33,7 @@ async def query_bionemo_nim(
35
  raise ValueError(f"Query failed: {e}")
36
 
37
 
38
- @op("MSA-search")
39
- @mem.cache
40
  async def msa_search(
41
  bundle: Bundle,
42
  *,
@@ -74,8 +71,7 @@ async def msa_search(
74
  return bundle
75
 
76
 
77
- @op("Query OpenFold2")
78
- @mem.cache
79
  async def query_openfold2(
80
  bundle: Bundle,
81
  *,
@@ -135,8 +131,7 @@ def view_molecule(
135
  }
136
 
137
 
138
- @op("Query GenMol")
139
- @mem.cache
140
  async def query_genmol(
141
  bundle: Bundle,
142
  *,
@@ -166,8 +161,7 @@ async def query_genmol(
166
  return bundle
167
 
168
 
169
- @op("Query DiffDock")
170
- @mem.cache
171
  async def query_diffdock(
172
  proteins: Bundle,
173
  ligands: Bundle,
 
2
 
3
  from lynxkite_graph_analytics import Bundle
4
  from lynxkite.core import ops
 
5
  import httpx
6
  import pandas as pd
7
  import os
8
 
9
 
 
10
  ENV = "LynxKite Graph Analytics"
11
  op = ops.op_registration(ENV)
12
 
 
33
  raise ValueError(f"Query failed: {e}")
34
 
35
 
36
+ @op("MSA-search", slow=True)
 
37
  async def msa_search(
38
  bundle: Bundle,
39
  *,
 
71
  return bundle
72
 
73
 
74
+ @op("Query OpenFold2", slow=True)
 
75
  async def query_openfold2(
76
  bundle: Bundle,
77
  *,
 
131
  }
132
 
133
 
134
+ @op("Query GenMol", slow=True)
 
135
  async def query_genmol(
136
  bundle: Bundle,
137
  *,
 
161
  return bundle
162
 
163
 
164
+ @op("Query DiffDock", slow=True)
 
165
  async def query_diffdock(
166
  proteins: Bundle,
167
  ligands: Bundle,
lynxkite-bio/src/lynxkite_bio/rdkit.py CHANGED
@@ -2,7 +2,6 @@
2
 
3
  from lynxkite_graph_analytics import Bundle, RelationDefinition
4
  from lynxkite.core import ops
5
- import joblib
6
  import numpy as np
7
  import pandas as pd
8
  import rdkit.Chem
@@ -10,7 +9,6 @@ import rdkit.Chem.rdFingerprintGenerator
10
  import rdkit.Chem.Fingerprints.ClusterMols
11
  import scipy
12
 
13
- mem = joblib.Memory(".joblib-cache")
14
  ENV = "LynxKite Graph Analytics"
15
  op = ops.op_registration(ENV)
16
 
 
2
 
3
  from lynxkite_graph_analytics import Bundle, RelationDefinition
4
  from lynxkite.core import ops
 
5
  import numpy as np
6
  import pandas as pd
7
  import rdkit.Chem
 
9
  import rdkit.Chem.Fingerprints.ClusterMols
10
  import scipy
11
 
 
12
  ENV = "LynxKite Graph Analytics"
13
  op = ops.op_registration(ENV)
14
 
lynxkite-core/src/lynxkite/core/ops.py CHANGED
@@ -255,7 +255,7 @@ def op(
255
  func = matplotlib_to_image(func)
256
  if slow:
257
  func = mem.cache(func)
258
- func = _global_slow(func)
259
  # Positional arguments are inputs.
260
  inputs = [
261
  Input(name=name, type=param.annotation)
@@ -385,9 +385,13 @@ def passive_op_registration(env: str):
385
  return functools.partial(register_passive_op, env)
386
 
387
 
388
- def slow(func):
389
  """Decorator for slow, blocking operations. Turns them into separate threads."""
390
 
 
 
 
 
391
  @functools.wraps(func)
392
  async def wrapper(*args, **kwargs):
393
  return await asyncio.to_thread(func, *args, **kwargs)
@@ -395,7 +399,6 @@ def slow(func):
395
  return wrapper
396
 
397
 
398
- _global_slow = slow # For access inside op().
399
  CATALOGS_SNAPSHOTS: dict[str, Catalogs] = {}
400
 
401
 
 
255
  func = matplotlib_to_image(func)
256
  if slow:
257
  func = mem.cache(func)
258
+ func = make_async(func)
259
  # Positional arguments are inputs.
260
  inputs = [
261
  Input(name=name, type=param.annotation)
 
385
  return functools.partial(register_passive_op, env)
386
 
387
 
388
+ def make_async(func):
389
  """Decorator for slow, blocking operations. Turns them into separate threads."""
390
 
391
+ if asyncio.iscoroutinefunction(func):
392
+ # If the function is already a coroutine, return it as is.
393
+ return func
394
+
395
  @functools.wraps(func)
396
  async def wrapper(*args, **kwargs):
397
  return await asyncio.to_thread(func, *args, **kwargs)
 
399
  return wrapper
400
 
401
 
 
402
  CATALOGS_SNAPSHOTS: dict[str, Catalogs] = {}
403
 
404
 
lynxkite-graph-analytics/src/lynxkite_graph_analytics/bionemo_ops.py CHANGED
@@ -12,7 +12,6 @@ import tarfile
12
  import os
13
  from collections import Counter
14
  from . import core
15
- import joblib
16
  import numpy as np
17
  import torch
18
  from pathlib import Path
@@ -40,7 +39,6 @@ from bionemo.scdl.io.single_cell_collection import SingleCellCollection
40
  import scanpy
41
 
42
 
43
- mem = joblib.Memory(".joblib-cache")
44
  op = ops.op_registration(core.ENV)
45
  DATA_PATH = Path("/workspace")
46
 
@@ -56,8 +54,7 @@ def random_seed(seed: int):
56
  random.setstate(state)
57
 
58
 
59
- @op("BioNeMo > Download CELLxGENE dataset")
60
- @mem.cache()
61
  def download_cellxgene_dataset(
62
  *,
63
  save_path: str,
@@ -99,8 +96,7 @@ def import_h5ad(*, file_path: str):
99
  return scanpy.read_h5ad(DATA_PATH / Path(file_path))
100
 
101
 
102
- @op("BioNeMo > Download model")
103
- @mem.cache(verbose=1)
104
  def download_model(*, model_name: str) -> str:
105
  """Downloads a model."""
106
  model_download_parameters = {
@@ -144,8 +140,7 @@ def download_model(*, model_name: str) -> str:
144
  return model_filename
145
 
146
 
147
- @op("BioNeMo > Infer")
148
- @mem.cache(verbose=1)
149
  def infer(dataset_path: str, model_path: str | None = None, *, results_path: str) -> str:
150
  """Infer on a dataset."""
151
  # This import is slow, so we only import it when we need it.
@@ -218,8 +213,7 @@ def plot_labels(adata):
218
  return options
219
 
220
 
221
- @op("BioNeMo > Run benchmark")
222
- @mem.cache(verbose=1)
223
  def run_benchmark(data, labels, *, use_pca: bool = False):
224
  """
225
  data - contains the single cell expression (or whatever feature) in each row.
@@ -277,8 +271,7 @@ def run_benchmark(data, labels, *, use_pca: bool = False):
277
  return results_out, conf_matrix
278
 
279
 
280
- @op("BioNeMo > Plot confusion matrix", view="visualization")
281
- @mem.cache(verbose=1)
282
  def plot_confusion_matrix(benchmark_output, labels):
283
  cm = benchmark_output[1]
284
  labels = labels.classes_
 
12
  import os
13
  from collections import Counter
14
  from . import core
 
15
  import numpy as np
16
  import torch
17
  from pathlib import Path
 
39
  import scanpy
40
 
41
 
 
42
  op = ops.op_registration(core.ENV)
43
  DATA_PATH = Path("/workspace")
44
 
 
54
  random.setstate(state)
55
 
56
 
57
+ @op("BioNeMo > Download CELLxGENE dataset", slow=True)
 
58
  def download_cellxgene_dataset(
59
  *,
60
  save_path: str,
 
96
  return scanpy.read_h5ad(DATA_PATH / Path(file_path))
97
 
98
 
99
+ @op("BioNeMo > Download model", slow=True)
 
100
  def download_model(*, model_name: str) -> str:
101
  """Downloads a model."""
102
  model_download_parameters = {
 
140
  return model_filename
141
 
142
 
143
+ @op("BioNeMo > Infer", slow=True)
 
144
  def infer(dataset_path: str, model_path: str | None = None, *, results_path: str) -> str:
145
  """Infer on a dataset."""
146
  # This import is slow, so we only import it when we need it.
 
213
  return options
214
 
215
 
216
+ @op("BioNeMo > Run benchmark", slow=True)
 
217
  def run_benchmark(data, labels, *, use_pca: bool = False):
218
  """
219
  data - contains the single cell expression (or whatever feature) in each row.
 
271
  return results_out, conf_matrix
272
 
273
 
274
+ @op("BioNeMo > Plot confusion matrix", view="visualization", slow=True)
 
275
  def plot_confusion_matrix(benchmark_output, labels):
276
  cm = benchmark_output[1]
277
  labels = labels.classes_
lynxkite-graph-analytics/src/lynxkite_graph_analytics/ml_ops.py CHANGED
@@ -8,12 +8,10 @@ from lynxkite.core import workspace
8
  from .pytorch import pytorch_core
9
  from lynxkite.core import ops
10
  from tqdm import tqdm
11
- import joblib
12
  import pandas as pd
13
  import pathlib
14
 
15
 
16
- mem = joblib.Memory(".joblib-cache")
17
  op = ops.op_registration(core.ENV)
18
 
19
 
@@ -57,8 +55,7 @@ class ModelOutputMapping(pytorch_core.ModelMapping):
57
  pass
58
 
59
 
60
- @op("Train model")
61
- @ops.slow
62
  def train_model(
63
  bundle: core.Bundle,
64
  *,
@@ -82,8 +79,7 @@ def train_model(
82
  return bundle
83
 
84
 
85
- @op("Model inference")
86
- @ops.slow
87
  def model_inference(
88
  bundle: core.Bundle,
89
  *,
 
8
  from .pytorch import pytorch_core
9
  from lynxkite.core import ops
10
  from tqdm import tqdm
 
11
  import pandas as pd
12
  import pathlib
13
 
14
 
 
15
  op = ops.op_registration(core.ENV)
16
 
17
 
 
55
  pass
56
 
57
 
58
+ @op("Train model", slow=True)
 
59
  def train_model(
60
  bundle: core.Bundle,
61
  *,
 
79
  return bundle
80
 
81
 
82
+ @op("Model inference", slow=True)
 
83
  def model_inference(
84
  bundle: core.Bundle,
85
  *,
lynxkite-graph-analytics/src/lynxkite_graph_analytics/networkx_ops.py CHANGED
@@ -156,7 +156,7 @@ def wrapped(name: str, func):
156
  for k, v in kwargs.items():
157
  if v == "None":
158
  kwargs[k] = None
159
- res = await ops.slow(func)(*args, **kwargs)
160
  # Figure out what the returned value is.
161
  if isinstance(res, nx.Graph):
162
  return res
 
156
  for k, v in kwargs.items():
157
  if v == "None":
158
  kwargs[k] = None
159
+ res = await ops.run_in_thread(func)(*args, **kwargs)
160
  # Figure out what the returned value is.
161
  if isinstance(res, nx.Graph):
162
  return res
lynxkite-lynxscribe/src/lynxkite_lynxscribe/lynxscribe_ops.py CHANGED
@@ -8,7 +8,6 @@ from copy import deepcopy
8
  from enum import Enum
9
  import asyncio
10
  import pandas as pd
11
- import joblib
12
  from pydantic import BaseModel, ConfigDict
13
 
14
  import pathlib
@@ -39,7 +38,6 @@ DEFAULT_NEGATIVE_ANSWER = "I'm sorry, but the data I've been trained on does not
39
 
40
  ENV = "LynxScribe"
41
  one_by_one.register(ENV)
42
- mem = joblib.Memory("joblib-cache")
43
  op = ops.op_registration(ENV)
44
  output_on_top = ops.output_position(output="top")
45
 
@@ -149,8 +147,7 @@ def cloud_file_loader(
149
 
150
 
151
  # @output_on_top
152
- # @op("LynxScribe RAG Graph Vector Store")
153
- # @mem.cache
154
  # def ls_rag_graph(
155
  # *,
156
  # name: str = "faiss",
@@ -187,8 +184,7 @@ def cloud_file_loader(
187
  # return {"rag_graph": rag_graph}
188
 
189
 
190
- @op("LynxScribe Image Describer")
191
- @mem.cache
192
  async def ls_image_describer(
193
  file_urls,
194
  *,
@@ -251,8 +247,7 @@ async def ls_image_describer(
251
  return {"image_descriptions": image_descriptions}
252
 
253
 
254
- @op("LynxScribe Image RAG Builder")
255
- @mem.cache
256
  async def ls_image_rag_builder(
257
  image_descriptions,
258
  *,
@@ -407,8 +402,7 @@ def view_image(embedding_similarities):
407
  return embedding_similarities[0]["image_url"]
408
 
409
 
410
- @op("LynxScribe Text RAG Loader")
411
- @mem.cache
412
  def ls_text_rag_loader(
413
  file_urls,
414
  *,
@@ -465,8 +459,7 @@ def ls_text_rag_loader(
465
  return {"rag_graph": rag_graph}
466
 
467
 
468
- @op("LynxScribe FAQ to RAG")
469
- @mem.cache
470
  async def ls_faq_to_rag(
471
  *,
472
  faq_excel_path: str = "",
@@ -712,8 +705,7 @@ def read_excel(*, file_path: str, sheet_name: str = "Sheet1", columns: str = "")
712
 
713
 
714
  @ops.input_position(system_prompt="bottom", instruction_prompt="bottom", dataframe="left")
715
- @op("LynxScribe Task Solver")
716
- @mem.cache
717
  async def ls_task_solver(
718
  system_prompt,
719
  instruction_prompt,
@@ -814,7 +806,7 @@ def mask(*, name="", regex="", exceptions="", mask_pattern=""):
814
 
815
 
816
  @ops.input_position(chat_api="bottom")
817
- @op("Test Chat API")
818
  async def test_chat_api(message, chat_api, *, show_details=False):
819
  chat_api = chat_api[0]["chat_api"]
820
  request = ChatCompletionPrompt(
 
8
  from enum import Enum
9
  import asyncio
10
  import pandas as pd
 
11
  from pydantic import BaseModel, ConfigDict
12
 
13
  import pathlib
 
38
 
39
  ENV = "LynxScribe"
40
  one_by_one.register(ENV)
 
41
  op = ops.op_registration(ENV)
42
  output_on_top = ops.output_position(output="top")
43
 
 
147
 
148
 
149
  # @output_on_top
150
+ # @op("LynxScribe RAG Graph Vector Store", slow=True)
 
151
  # def ls_rag_graph(
152
  # *,
153
  # name: str = "faiss",
 
184
  # return {"rag_graph": rag_graph}
185
 
186
 
187
+ @op("LynxScribe Image Describer", slow=True)
 
188
  async def ls_image_describer(
189
  file_urls,
190
  *,
 
247
  return {"image_descriptions": image_descriptions}
248
 
249
 
250
+ @op("LynxScribe Image RAG Builder", slow=True)
 
251
  async def ls_image_rag_builder(
252
  image_descriptions,
253
  *,
 
402
  return embedding_similarities[0]["image_url"]
403
 
404
 
405
+ @op("LynxScribe Text RAG Loader", slow=True)
 
406
  def ls_text_rag_loader(
407
  file_urls,
408
  *,
 
459
  return {"rag_graph": rag_graph}
460
 
461
 
462
+ @op("LynxScribe FAQ to RAG", slow=True)
 
463
  async def ls_faq_to_rag(
464
  *,
465
  faq_excel_path: str = "",
 
705
 
706
 
707
  @ops.input_position(system_prompt="bottom", instruction_prompt="bottom", dataframe="left")
708
+ @op("LynxScribe Task Solver", slow=True)
 
709
  async def ls_task_solver(
710
  system_prompt,
711
  instruction_prompt,
 
806
 
807
 
808
  @ops.input_position(chat_api="bottom")
809
+ @op("Test Chat API", slow=True)
810
  async def test_chat_api(message, chat_api, *, show_details=False):
811
  chat_api = chat_api[0]["chat_api"]
812
  request = ChatCompletionPrompt(