Pusheen commited on
Commit
5060186
·
verified ·
1 Parent(s): e39222b

Update gligen/ldm/models/diffusion/plms.py

Browse files
gligen/ldm/models/diffusion/plms.py CHANGED
@@ -142,6 +142,7 @@ class PLMSSampler(object):
142
  input["timesteps"] = ts
143
 
144
  # print("optimize", index1)
 
145
  while loss.item() > loss_threshold and iteration < max_iter and (index1 < max_index) :
146
  # print('iter', iteration)
147
  x = x.requires_grad_(True)
 
142
  input["timesteps"] = ts
143
 
144
  # print("optimize", index1)
145
+ self.model.train()
146
  while loss.item() > loss_threshold and iteration < max_iter and (index1 < max_index) :
147
  # print('iter', iteration)
148
  x = x.requires_grad_(True)