darabos commited on
Commit
0a73f59
·
2 Parent(s): f56dc7e 13fabc5

Merge pull request #204 from biggraph/darabos-slow

Browse files

Make it simpler to declare operations "slow"

lynxkite-core/src/lynxkite/core/ops.py CHANGED
@@ -255,7 +255,7 @@ def op(
255
  func = matplotlib_to_image(func)
256
  if slow:
257
  func = mem.cache(func)
258
- func = _global_slow(func)
259
  # Positional arguments are inputs.
260
  inputs = [
261
  Input(name=name, type=param.annotation)
@@ -385,9 +385,13 @@ def passive_op_registration(env: str):
385
  return functools.partial(register_passive_op, env)
386
 
387
 
388
- def slow(func):
389
  """Decorator for slow, blocking operations. Turns them into separate threads."""
390
 
 
 
 
 
391
  @functools.wraps(func)
392
  async def wrapper(*args, **kwargs):
393
  return await asyncio.to_thread(func, *args, **kwargs)
@@ -395,7 +399,6 @@ def slow(func):
395
  return wrapper
396
 
397
 
398
- _global_slow = slow # For access inside op().
399
  CATALOGS_SNAPSHOTS: dict[str, Catalogs] = {}
400
 
401
 
 
255
  func = matplotlib_to_image(func)
256
  if slow:
257
  func = mem.cache(func)
258
+ func = make_async(func)
259
  # Positional arguments are inputs.
260
  inputs = [
261
  Input(name=name, type=param.annotation)
 
385
  return functools.partial(register_passive_op, env)
386
 
387
 
388
+ def make_async(func):
389
  """Decorator for slow, blocking operations. Turns them into separate threads."""
390
 
391
+ if asyncio.iscoroutinefunction(func):
392
+ # If the function is already a coroutine, return it as is.
393
+ return func
394
+
395
  @functools.wraps(func)
396
  async def wrapper(*args, **kwargs):
397
  return await asyncio.to_thread(func, *args, **kwargs)
 
399
  return wrapper
400
 
401
 
 
402
  CATALOGS_SNAPSHOTS: dict[str, Catalogs] = {}
403
 
404
 
lynxkite-graph-analytics/src/lynxkite_graph_analytics/ml_ops.py CHANGED
@@ -8,12 +8,10 @@ from lynxkite.core import workspace
8
  from .pytorch import pytorch_core
9
  from lynxkite.core import ops
10
  from tqdm import tqdm
11
- import joblib
12
  import pandas as pd
13
  import pathlib
14
 
15
 
16
- mem = joblib.Memory(".joblib-cache")
17
  op = ops.op_registration(core.ENV)
18
 
19
 
@@ -57,8 +55,7 @@ class ModelOutputMapping(pytorch_core.ModelMapping):
57
  pass
58
 
59
 
60
- @op("Train model")
61
- @ops.slow
62
  def train_model(
63
  bundle: core.Bundle,
64
  *,
@@ -82,8 +79,7 @@ def train_model(
82
  return bundle
83
 
84
 
85
- @op("Model inference")
86
- @ops.slow
87
  def model_inference(
88
  bundle: core.Bundle,
89
  *,
 
8
  from .pytorch import pytorch_core
9
  from lynxkite.core import ops
10
  from tqdm import tqdm
 
11
  import pandas as pd
12
  import pathlib
13
 
14
 
 
15
  op = ops.op_registration(core.ENV)
16
 
17
 
 
55
  pass
56
 
57
 
58
+ @op("Train model", slow=True)
 
59
  def train_model(
60
  bundle: core.Bundle,
61
  *,
 
79
  return bundle
80
 
81
 
82
+ @op("Model inference", slow=True)
 
83
  def model_inference(
84
  bundle: core.Bundle,
85
  *,
lynxkite-graph-analytics/src/lynxkite_graph_analytics/networkx_ops.py CHANGED
@@ -156,7 +156,7 @@ def wrapped(name: str, func):
156
  for k, v in kwargs.items():
157
  if v == "None":
158
  kwargs[k] = None
159
- res = await ops.slow(func)(*args, **kwargs)
160
  # Figure out what the returned value is.
161
  if isinstance(res, nx.Graph):
162
  return res
 
156
  for k, v in kwargs.items():
157
  if v == "None":
158
  kwargs[k] = None
159
+ res = await ops.run_in_thread(func)(*args, **kwargs)
160
  # Figure out what the returned value is.
161
  if isinstance(res, nx.Graph):
162
  return res