"""Boxes for defining PyTorch models."""
import graphlib
from lynxkite.core import ops, workspace
from lynxkite.core.ops import Parameter as P
import torch
import torch_geometric as pyg
from dataclasses import dataclass
from . import core
ENV = "PyTorch model"


def reg(name, inputs=[], outputs=None, params=[]):
    if outputs is None:
        outputs = inputs
    return ops.register_passive_op(
        ENV,
        name,
        inputs=[
            ops.Input(name=name, position="bottom", type="tensor") for name in inputs
        ],
        outputs=[
            ops.Output(name=name, position="top", type="tensor") for name in outputs
        ],
        params=params,
    )


reg("Input: embedding", outputs=["x"])
reg("Input: graph edges", outputs=["edges"])
reg("Input: label", outputs=["y"])
reg("Input: positive sample", outputs=["x_pos"])
reg("Input: negative sample", outputs=["x_neg"])
reg("Input: sequential", outputs=["y"])
reg("Input: zeros", outputs=["x"])
reg("LSTM", inputs=["x", "h"], outputs=["x", "h"])
reg(
    "Neural ODE",
    inputs=["x"],
    params=[
        P.basic("relative_tolerance"),
        P.basic("absolute_tolerance"),
        P.options(
            "method",
            [
                "dopri8",
                "dopri5",
                "bosh3",
                "fehlberg2",
                "adaptive_heun",
                "euler",
                "midpoint",
                "rk4",
                "explicit_adams",
                "implicit_adams",
            ],
        ),
    ],
)
reg("Attention", inputs=["q", "k", "v"], outputs=["x", "weights"])
reg("LayerNorm", inputs=["x"])
reg("Dropout", inputs=["x"], params=[P.basic("p", 0.5)])
reg("Linear", inputs=["x"], params=[P.basic("output_dim", "same")])
reg("Softmax", inputs=["x"])
reg(
    "Graph conv",
    inputs=["x", "edges"],
    outputs=["x"],
    params=[P.options("type", ["GCNConv", "GATConv", "GATv2Conv", "SAGEConv"])],
)
reg(
    "Activation",
    inputs=["x"],
    params=[P.options("type", ["ReLU", "Leaky ReLU", "Tanh", "Mish"])],
)
reg("Concatenate", inputs=["a", "b"], outputs=["x"])
reg("Add", inputs=["a", "b"], outputs=["x"])
reg("Subtract", inputs=["a", "b"], outputs=["x"])
reg("Multiply", inputs=["a", "b"], outputs=["x"])
reg("MSE loss", inputs=["x", "y"], outputs=["loss"])
reg("Triplet margin loss", inputs=["x", "x_pos", "x_neg"], outputs=["loss"])
reg("Cross-entropy loss", inputs=["x", "y"], outputs=["loss"])
reg(
    "Optimizer",
    inputs=["loss"],
    outputs=[],
    params=[
        P.options(
            "type",
            [
                "AdamW",
                "Adafactor",
                "Adagrad",
                "SGD",
                "Lion",
                "Paged AdamW",
                "Galore AdamW",
            ],
        ),
        P.basic("lr", 0.001),
    ],
)
ops.register_passive_op(
    ENV,
    "Repeat",
    inputs=[ops.Input(name="input", position="top", type="tensor")],
    outputs=[ops.Output(name="output", position="bottom", type="tensor")],
    params=[
        ops.Parameter.basic("times", 1, int),
        ops.Parameter.basic("same_weights", True, bool),
    ],
)
ops.register_passive_op(
    ENV,
    "Recurrent chain",
    inputs=[ops.Input(name="input", position="top", type="tensor")],
    outputs=[ops.Output(name="output", position="bottom", type="tensor")],
    params=[],
)


def _to_id(s: str) -> str:
    """Replaces all non-alphanumeric characters with underscores."""
    return "".join(c if c.isalnum() else "_" for c in s)


@dataclass
class ModelConfig:
    model: torch.nn.Module
    model_inputs: list[str]
    model_outputs: list[str]
    loss_inputs: list[str]
    loss: torch.nn.Module
    optimizer: torch.optim.Optimizer

    def _forward(self, inputs: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        model_inputs = [inputs[i] for i in self.model_inputs]
        output = self.model(*model_inputs)
        if not isinstance(output, tuple):
            output = (output,)
        values = {k: v for k, v in zip(self.model_outputs, output)}
        return values

    def inference(self, inputs: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        # TODO: Do multiple batches.
        self.model.eval()
        return self._forward(inputs)

    def train(self, inputs: dict[str, torch.Tensor]) -> float:
        """Train the model for one epoch. Returns the loss."""
        # TODO: Do multiple batches.
        self.model.train()
        self.optimizer.zero_grad()
        values = self._forward(inputs)
        values.update(inputs)
        loss_inputs = [values[i] for i in self.loss_inputs]
        loss = self.loss(*loss_inputs)
        loss.backward()
        self.optimizer.step()
        return loss.item()

    def copy(self):
        """Returns a copy of the model."""
        # Dataclasses have no copy() and torch.nn.Module has no copy() either,
        # so shallow-copy the config and deep-copy the model.
        c = copy.copy(self)
        c.model = copy.deepcopy(self.model)
        return c
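

# Note: the keys of the `inputs` dicts above are "<node id>_<output handle>"
# strings as produced by _to_id() in build_model() below, e.g. a hypothetical
#   cfg.train({"Input__embedding_1_x": x, "Input__label_1_y": y})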


def build_model(
    ws: workspace.Workspace, inputs: dict[str, torch.Tensor]
) -> ModelConfig:
    """Builds the model described in the workspace."""
    catalog = ops.CATALOGS[ENV]
    optimizers = []
    nodes = {}
    for node in ws.nodes:
        nodes[node.id] = node
        if node.data.title == "Optimizer":
            optimizers.append(node.id)
    assert optimizers, "No optimizer found."
    assert len(optimizers) == 1, f"More than one optimizer found: {optimizers}"
    [optimizer] = optimizers
    dependencies = {n.id: [] for n in ws.nodes}
    edges = {}
    # TODO: Dissolve repeat boxes here.
    for e in ws.edges:
        dependencies[e.target].append(e.source)
        edges.setdefault((e.target, e.targetHandle), []).append(
            (e.source, e.sourceHandle)
        )
    sizes = {}
    for k, i in inputs.items():
        sizes[k] = i.shape[-1]
    ts = graphlib.TopologicalSorter(dependencies)
    layers = []
    loss_layers = []
    in_loss = set()
    cfg = {}
    loss_inputs = set()
    used_inputs = set()
    for node_id in ts.static_order():
        node = nodes[node_id]
        t = node.data.title
        op = catalog[t]
        p = op.convert_params(node.data.params)
        # A node is part of the loss computation if any of its dependencies is.
        for b in dependencies[node_id]:
            if b in in_loss:
                in_loss.add(node_id)
        ls = loss_layers if node_id in in_loss else layers
        nid = _to_id(node_id)
        # Only a few box types are translated to layers so far.
        match t:
            case "Linear":
                [(ib, ih)] = edges[node_id, "x"]
                i = _to_id(ib) + "_" + ih
                used_inputs.add(i)
                isize = sizes[i]
                osize = isize if p["output_dim"] == "same" else int(p["output_dim"])
                ls.append((torch.nn.Linear(isize, osize), f"{i} -> {nid}_x"))
                sizes[f"{nid}_x"] = osize
            case "Activation":
                [(ib, ih)] = edges[node_id, "x"]
                i = _to_id(ib) + "_" + ih
                used_inputs.add(i)
                f = getattr(
                    torch.nn.functional, p["type"].name.lower().replace(" ", "_")
                )
                ls.append((f, f"{i} -> {nid}_x"))
                sizes[f"{nid}_x"] = sizes[i]
            case "MSE loss":
                [(xb, xh)] = edges[node_id, "x"]
                xi = _to_id(xb) + "_" + xh
                [(yb, yh)] = edges[node_id, "y"]
                yi = _to_id(yb) + "_" + yh
                loss_inputs.add(xi)
                loss_inputs.add(yi)
                in_loss.add(node_id)
                loss_layers.append(
                    (torch.nn.functional.mse_loss, f"{xi}, {yi} -> {nid}_loss")
                )
    cfg["model_inputs"] = used_inputs & inputs.keys()
    cfg["model_outputs"] = loss_inputs - inputs.keys()
    cfg["loss_inputs"] = loss_inputs
    # Make sure the trained output is output from the last model layer.
    outputs = ", ".join(cfg["model_outputs"])
    layers.append((torch.nn.Identity(), f"{outputs} -> {outputs}"))
    # Create model.
    cfg["model"] = pyg.nn.Sequential(", ".join(used_inputs & inputs.keys()), layers)
    # Make sure the loss is output from the last loss layer.
    [(lossb, lossh)] = edges[optimizer, "loss"]
    lossi = _to_id(lossb) + "_" + lossh
    loss_layers.append((torch.nn.Identity(), f"{lossi} -> loss"))
    # Create loss function.
    cfg["loss"] = pyg.nn.Sequential(", ".join(loss_inputs), loss_layers)
    assert not list(cfg["loss"].parameters()), (
        f"loss should have no parameters: {list(cfg['loss'].parameters())}"
    )
    # Create optimizer.
    op = catalog["Optimizer"]
    p = op.convert_params(nodes[optimizer].data.params)
    o = getattr(torch.optim, p["type"].name)
    cfg["optimizer"] = o(cfg["model"].parameters(), lr=p["lr"])
    return ModelConfig(**cfg)
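

# Example (sketch): building and training a model from a workspace. `ws`,
# `input_tensors`, and `epochs` are hypothetical names supplied by the caller.
#
#   cfg = build_model(ws, input_tensors)
#   for _ in range(epochs):
#       loss = cfg.train(input_tensors)
#   predictions = cfg.inference(input_tensors)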


def to_tensors(b: core.Bundle, m: dict[str, dict]) -> dict[str, torch.Tensor]:
    """Converts the mapped DataFrame columns of a bundle to PyTorch tensors."""
    tensors = {}
    for k, v in m.items():
        tensors[k] = torch.tensor(
            b.dfs[v["df"]][v["column"]].to_list(), dtype=torch.float32
        )
    return tensors
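

# Example (sketch): the mapping passed to to_tensors() names a DataFrame and a
# column per tensor; "nodes" and "feature" are hypothetical names.
#
#   tensors = to_tensors(bundle, {"x": {"df": "nodes", "column": "feature"}})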