darabos commited on
Commit
a961ac6
·
1 Parent(s): f41635e

Make new optimizer when model is copied.

Browse files
lynxkite-graph-analytics/src/lynxkite_graph_analytics/pytorch_model_ops.py CHANGED
@@ -189,10 +189,14 @@ class ModelConfig:
189
  model_outputs: list[str]
190
  loss_inputs: list[str]
191
  loss: torch.nn.Module
192
- optimizer: torch.optim.Optimizer
 
193
  source_workspace: str | None = None
194
  trained: bool = False
195
 
 
 
 
196
  def num_parameters(self) -> int:
197
  return sum(p.numel() for p in self.model.parameters())
198
 
@@ -222,10 +226,20 @@ class ModelConfig:
222
  self.optimizer.step()
223
  return loss.item()
224
 
 
 
 
 
 
 
225
  def copy(self):
226
  """Returns a copy of the model."""
227
- c = dataclasses.replace(self)
228
- c.model = copy.deepcopy(self.model)
 
 
 
 
229
  return c
230
 
231
  def metadata(self):
@@ -451,9 +465,7 @@ class ModelBuilder:
451
  assert not list(cfg["loss"].parameters()), f"loss should have no parameters: {loss_layers}"
452
  # Create optimizer.
453
  op = self.catalog["Optimizer"]
454
- p = op.convert_params(self.nodes[self.optimizer].data.params)
455
- o = getattr(torch.optim, p["type"].name)
456
- cfg["optimizer"] = o(cfg["model"].parameters(), lr=p["lr"])
457
  return ModelConfig(**cfg)
458
 
459
 
 
189
  model_outputs: list[str]
190
  loss_inputs: list[str]
191
  loss: torch.nn.Module
192
+ optimizer_parameters: dict[str, any]
193
+ optimizer: torch.optim.Optimizer | None = None
194
  source_workspace: str | None = None
195
  trained: bool = False
196
 
197
+ def __post_init__(self):
198
+ self._make_optimizer()
199
+
200
  def num_parameters(self) -> int:
201
  return sum(p.numel() for p in self.model.parameters())
202
 
 
226
  self.optimizer.step()
227
  return loss.item()
228
 
229
+ def _make_optimizer(self):
230
+ # We need to make a new optimizer when the model is copied. (It's tied to its parameters.)
231
+ p = self.optimizer_parameters
232
+ o = getattr(torch.optim, p["type"].name)
233
+ self.optimizer = o(self.model.parameters(), lr=p["lr"])
234
+
235
  def copy(self):
236
  """Returns a copy of the model."""
237
+ c = dataclasses.replace(
238
+ self,
239
+ model=copy.deepcopy(self.model),
240
+ )
241
+ c._make_optimizer()
242
+ c.optimizer.load_state_dict(self.optimizer.state_dict())
243
  return c
244
 
245
  def metadata(self):
 
465
  assert not list(cfg["loss"].parameters()), f"loss should have no parameters: {loss_layers}"
466
  # Create optimizer.
467
  op = self.catalog["Optimizer"]
468
+ cfg["optimizer_parameters"] = op.convert_params(self.nodes[self.optimizer].data.params)
 
 
469
  return ModelConfig(**cfg)
470
 
471