darabos committed
Commit 1ffbcc2 · 1 Parent(s): 2c239c4

Dependency-based repeat instead of using regions.

lynxkite-graph-analytics/src/lynxkite_graph_analytics/pytorch_model_ops.py CHANGED
@@ -246,6 +246,34 @@ class ModelConfig:
     }
 
 
+def _add_op(op, params, inputs, outputs, sizes, layers):
+    op_inputs = []
+    for i in op.inputs.keys():
+        id = getattr(inputs, i)
+        op_inputs.append(OpInput(id, shape=sizes.get(id, 1)))
+    if op.func != ops.no_op:
+        layer = op.func(*op_inputs, **params)
+    else:
+        layer = Layer(torch.nn.Identity(), shapes=[i.shape for i in op_inputs])
+    input_ids = ", ".join(i.id for i in op_inputs)
+    output_ids = []
+    for o, shape in zip(op.outputs.keys(), layer.shapes):
+        id = getattr(outputs, o)
+        output_ids.append(id)
+        sizes[id] = shape
+    output_ids = ", ".join(output_ids)
+    layers.append((layer.module, f"{input_ids} -> {output_ids}"))
+
+
+def _all_dependencies(node: str, dependencies: dict[str, list[str]]) -> set[str]:
+    """Returns all dependencies of a node."""
+    deps = set()
+    for dep in dependencies[node]:
+        deps.add(dep)
+        deps.update(_all_dependencies(dep, dependencies))
+    return deps
+
+
 def build_model(ws: workspace.Workspace, inputs: dict[str, torch.Tensor]) -> ModelConfig:
     """Builds the model described in the workspace."""
     catalog = ops.CATALOGS[ENV]
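
A quick note on the second new helper: _all_dependencies recurses through the dependency map, so it returns the full transitive closure rather than just the direct neighbors, and it carries no memoization or cycle guard, so it assumes a reasonably small acyclic graph. Using the function exactly as defined in the hunk above, on an invented three-node graph:

deps = {"a": [], "b": ["a"], "c": ["a", "b"]}  # invented: c depends on a and b, b depends on a
assert _all_dependencies("c", deps) == {"a", "b"}  # includes a, reached transitively via b
assert _all_dependencies("b", deps) == {"a"}
assert _all_dependencies("a", deps) == set()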
@@ -259,6 +287,7 @@ def build_model(ws: workspace.Workspace, inputs: dict[str, torch.Tensor]) -> Mod
     assert len(optimizers) == 1, f"More than one optimizer found: {optimizers}"
     [optimizer] = optimizers
     dependencies = {n.id: [] for n in ws.nodes}
+    inv_dependencies = {n.id: [] for n in ws.nodes}
     in_edges = {}
     out_edges = {}
     repeats = []
@@ -266,6 +295,7 @@ def build_model(ws: workspace.Workspace, inputs: dict[str, torch.Tensor]) -> Mod
         if nodes[e.target].data.title == "Repeat":
             repeats.append(e.target)
         dependencies[e.target].append(e.source)
+        inv_dependencies[e.source].append(e.target)
         in_edges.setdefault(e.target, {}).setdefault(e.targetHandle, []).append(
             (e.source, e.sourceHandle)
         )
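
Aside: the forward and inverse adjacency lists are built in the same pass over the edges, so reachability can later be asked in either direction. A self-contained sketch of the same pattern, with an invented stand-in for the workspace edge type:

from dataclasses import dataclass

@dataclass
class Edge:  # invented stand-in for the workspace edge type
    source: str
    target: str

edges = [Edge("input_1", "linear_1"), Edge("linear_1", "loss_1")]
node_ids = ["input_1", "linear_1", "loss_1"]
dependencies = {n: [] for n in node_ids}      # target -> the nodes it consumes
inv_dependencies = {n: [] for n in node_ids}  # source -> the nodes that consume it
for e in edges:
    dependencies[e.target].append(e.source)
    inv_dependencies[e.source].append(e.target)
assert dependencies["loss_1"] == ["linear_1"]
assert inv_dependencies["linear_1"] == ["loss_1"]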
@@ -351,7 +381,7 @@ def build_model(ws: workspace.Workspace, inputs: dict[str, torch.Tensor]) -> Mod
         for out in out_edges.get(node_id, []):
             i = _to_id(node_id, out)
             outputs[out] = i
-            if inputs:  # Nodes with no inputs are input nodes. Their outputs are not "made" by us.
+            if not t.startswith("Input:"):  # The outputs of inputs are not "made" by us.
                 if "loss" in regions[node_id]:
                     made_in_loss.add(i)
                 else:
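
Why the guard on input nodes matters: anything used but never "made" ends up classified as an external input by the set arithmetic at the bottom of build_model. A toy run of that arithmetic, with invented tensor ids:

# Invented ids; mirrors the cfg["model_inputs"] / cfg["model_outputs"] / cfg["loss_inputs"] logic.
used_in_model = {"input_1_output", "linear_1_output"}
made_in_model = {"linear_1_output"}  # the input tensor is fed in, not made by a layer
used_in_loss = {"linear_1_output", "label"}
made_in_loss = {"loss_1_output"}

model_inputs = used_in_model - made_in_model  # {"input_1_output"}
model_outputs = made_in_model & used_in_loss  # {"linear_1_output"}
loss_inputs = used_in_loss - made_in_loss     # {"linear_1_output", "label"}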
@@ -374,31 +404,23 @@ def build_model(ws: workspace.Workspace, inputs: dict[str, torch.Tensor]) -> Mod
                     regions[node_id].add(("repeat", node_id.removeprefix("START ")))
                 else:
                     repeat_id = node_id.removeprefix("END ")
+                    start_id = f"START {repeat_id}"
                     print(f"repeat {repeat_id} ending")
+                    after_start = _all_dependencies(start_id, inv_dependencies)
+                    after_end = _all_dependencies(node_id, inv_dependencies)
+                    before_end = _all_dependencies(node_id, dependencies)
+                    affected_nodes = after_start - after_end
+                    repeated_nodes = after_start & before_end
+                    assert affected_nodes == repeated_nodes, (
+                        f"edges leave repeated section '{repeat_id}':\n{affected_nodes - repeated_nodes}"
+                    )
                     regions[node_id].remove(("repeat", repeat_id))
-                    for n in nodes:
-                        r = regions.get(n, set())
-                        if ("repeat", repeat_id) in r:
-                            print(f"repeating {n}")
+                    for n in repeated_nodes:
+                        print(f"repeating {n}")
             case "Optimizer" | "Input: tensor" | "Input: graph edges" | "Input: sequential":
                 pass
             case _:
-                op_inputs = []
-                for i in op.inputs.keys():
-                    id = getattr(inputs, i)
-                    op_inputs.append(OpInput(id, shape=sizes.get(id, 1)))
-                if op.func != ops.no_op:
-                    layer = op.func(*op_inputs, **p)
-                else:
-                    layer = Layer(torch.nn.Identity(), shapes=[i.shape for i in op_inputs])
-                input_ids = ", ".join(i.id for i in op_inputs)
-                output_ids = []
-                for o, shape in zip(op.outputs.keys(), layer.shapes):
-                    id = getattr(outputs, o)
-                    output_ids.append(id)
-                    sizes[id] = shape
-                output_ids = ", ".join(output_ids)
-                ls.append((layer.module, f"{input_ids} -> {output_ids}"))
+                _add_op(op, p, inputs, outputs, sizes, ls)
     cfg["model_inputs"] = list(used_in_model - made_in_model)
     cfg["model_outputs"] = list(made_in_model & used_in_loss)
     cfg["loss_inputs"] = list(used_in_loss - made_in_loss)
 
246
  }
247
 
248
 
249
+ def _add_op(op, params, inputs, outputs, sizes, layers):
250
+ op_inputs = []
251
+ for i in op.inputs.keys():
252
+ id = getattr(inputs, i)
253
+ op_inputs.append(OpInput(id, shape=sizes.get(id, 1)))
254
+ if op.func != ops.no_op:
255
+ layer = op.func(*op_inputs, **params)
256
+ else:
257
+ layer = Layer(torch.nn.Identity(), shapes=[i.shape for i in op_inputs])
258
+ input_ids = ", ".join(i.id for i in op_inputs)
259
+ output_ids = []
260
+ for o, shape in zip(op.outputs.keys(), layer.shapes):
261
+ id = getattr(outputs, o)
262
+ output_ids.append(id)
263
+ sizes[id] = shape
264
+ output_ids = ", ".join(output_ids)
265
+ layers.append((layer.module, f"{input_ids} -> {output_ids}"))
266
+
267
+
268
+ def _all_dependencies(node: str, dependencies: dict[str, list[str]]) -> set[str]:
269
+ """Returns all dependencies of a node."""
270
+ deps = set()
271
+ for dep in dependencies[node]:
272
+ deps.add(dep)
273
+ deps.update(_all_dependencies(dep, dependencies))
274
+ return deps
275
+
276
+
277
  def build_model(ws: workspace.Workspace, inputs: dict[str, torch.Tensor]) -> ModelConfig:
278
  """Builds the model described in the workspace."""
279
  catalog = ops.CATALOGS[ENV]
 
287
  assert len(optimizers) == 1, f"More than one optimizer found: {optimizers}"
288
  [optimizer] = optimizers
289
  dependencies = {n.id: [] for n in ws.nodes}
290
+ inv_dependencies = {n.id: [] for n in ws.nodes}
291
  in_edges = {}
292
  out_edges = {}
293
  repeats = []
 
295
  if nodes[e.target].data.title == "Repeat":
296
  repeats.append(e.target)
297
  dependencies[e.target].append(e.source)
298
+ inv_dependencies[e.source].append(e.target)
299
  in_edges.setdefault(e.target, {}).setdefault(e.targetHandle, []).append(
300
  (e.source, e.sourceHandle)
301
  )
 
381
  for out in out_edges.get(node_id, []):
382
  i = _to_id(node_id, out)
383
  outputs[out] = i
384
+ if not t.startswith("Input:"): # The outputs of inputs are not "made" by us.
385
  if "loss" in regions[node_id]:
386
  made_in_loss.add(i)
387
  else:
 
404
  regions[node_id].add(("repeat", node_id.removeprefix("START ")))
405
  else:
406
  repeat_id = node_id.removeprefix("END ")
407
+ start_id = f"START {repeat_id}"
408
  print(f"repeat {repeat_id} ending")
409
+ after_start = _all_dependencies(start_id, inv_dependencies)
410
+ after_end = _all_dependencies(node_id, inv_dependencies)
411
+ before_end = _all_dependencies(node_id, dependencies)
412
+ affected_nodes = after_start - after_end
413
+ repeated_nodes = after_start & before_end
414
+ assert affected_nodes == repeated_nodes, (
415
+ f"edges leave repeated section '{repeat_id}':\n{affected_nodes - repeated_nodes}"
416
+ )
417
  regions[node_id].remove(("repeat", repeat_id))
418
+ for n in repeated_nodes:
419
+ print(f"repeating {n}")
 
 
420
  case "Optimizer" | "Input: tensor" | "Input: graph edges" | "Input: sequential":
421
  pass
422
  case _:
423
+ _add_op(op, p, inputs, outputs, sizes, ls)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
424
  cfg["model_inputs"] = list(used_in_model - made_in_model)
425
  cfg["model_outputs"] = list(made_in_model & used_in_loss)
426
  cfg["loss_inputs"] = list(used_in_loss - made_in_loss)