Spaces:
Running
Running
Make new optimizer when model is copied.
Browse files
lynxkite-graph-analytics/src/lynxkite_graph_analytics/pytorch_model_ops.py
CHANGED
@@ -189,10 +189,14 @@ class ModelConfig:
|
|
189 |
model_outputs: list[str]
|
190 |
loss_inputs: list[str]
|
191 |
loss: torch.nn.Module
|
192 |
-
|
|
|
193 |
source_workspace: str | None = None
|
194 |
trained: bool = False
|
195 |
|
|
|
|
|
|
|
196 |
def num_parameters(self) -> int:
|
197 |
return sum(p.numel() for p in self.model.parameters())
|
198 |
|
@@ -222,10 +226,20 @@ class ModelConfig:
|
|
222 |
self.optimizer.step()
|
223 |
return loss.item()
|
224 |
|
|
|
|
|
|
|
|
|
|
|
|
|
225 |
def copy(self):
|
226 |
"""Returns a copy of the model."""
|
227 |
-
c = dataclasses.replace(
|
228 |
-
|
|
|
|
|
|
|
|
|
229 |
return c
|
230 |
|
231 |
def metadata(self):
|
@@ -451,9 +465,7 @@ class ModelBuilder:
|
|
451 |
assert not list(cfg["loss"].parameters()), f"loss should have no parameters: {loss_layers}"
|
452 |
# Create optimizer.
|
453 |
op = self.catalog["Optimizer"]
|
454 |
-
|
455 |
-
o = getattr(torch.optim, p["type"].name)
|
456 |
-
cfg["optimizer"] = o(cfg["model"].parameters(), lr=p["lr"])
|
457 |
return ModelConfig(**cfg)
|
458 |
|
459 |
|
|
|
189 |
model_outputs: list[str]
|
190 |
loss_inputs: list[str]
|
191 |
loss: torch.nn.Module
|
192 |
+
optimizer_parameters: dict[str, any]
|
193 |
+
optimizer: torch.optim.Optimizer | None = None
|
194 |
source_workspace: str | None = None
|
195 |
trained: bool = False
|
196 |
|
197 |
+
def __post_init__(self):
|
198 |
+
self._make_optimizer()
|
199 |
+
|
200 |
def num_parameters(self) -> int:
|
201 |
return sum(p.numel() for p in self.model.parameters())
|
202 |
|
|
|
226 |
self.optimizer.step()
|
227 |
return loss.item()
|
228 |
|
229 |
+
def _make_optimizer(self):
|
230 |
+
# We need to make a new optimizer when the model is copied. (It's tied to its parameters.)
|
231 |
+
p = self.optimizer_parameters
|
232 |
+
o = getattr(torch.optim, p["type"].name)
|
233 |
+
self.optimizer = o(self.model.parameters(), lr=p["lr"])
|
234 |
+
|
235 |
def copy(self):
|
236 |
"""Returns a copy of the model."""
|
237 |
+
c = dataclasses.replace(
|
238 |
+
self,
|
239 |
+
model=copy.deepcopy(self.model),
|
240 |
+
)
|
241 |
+
c._make_optimizer()
|
242 |
+
c.optimizer.load_state_dict(self.optimizer.state_dict())
|
243 |
return c
|
244 |
|
245 |
def metadata(self):
|
|
|
465 |
assert not list(cfg["loss"].parameters()), f"loss should have no parameters: {loss_layers}"
|
466 |
# Create optimizer.
|
467 |
op = self.catalog["Optimizer"]
|
468 |
+
cfg["optimizer_parameters"] = op.convert_params(self.nodes[self.optimizer].data.params)
|
|
|
|
|
469 |
return ModelConfig(**cfg)
|
470 |
|
471 |
|