Pusheen commited on
Commit
976890a
·
verified ·
1 Parent(s): f3bf446

Update gligen/ldm/models/diffusion/plms.py

Browse files
Files changed (1) hide show
  1. gligen/ldm/models/diffusion/plms.py +11 -7
gligen/ldm/models/diffusion/plms.py CHANGED
@@ -1,6 +1,6 @@
1
- import os
2
- os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "0"
3
- os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
4
  import torch
5
  import numpy as np
6
  from tqdm import tqdm
@@ -192,13 +192,19 @@ class PLMSSampler(object):
192
 
193
  print("optimize", index1)
194
  self.model.train()
195
- torch.cuda.empty_cache()
196
  while loss.item() > loss_threshold and iteration < max_iter and (index1 < max_index) :
197
  print('iter', iteration)
198
  # import pdb; pdb.set_trace()
199
  x = x.requires_grad_(True)
200
  input['x'] = x
201
  e_t, att_first, att_second, att_third, self_first, self_second, self_third = self.model(input)
 
 
 
 
 
 
202
  bboxes = input['boxes_att']
203
  object_positions = input['object_position']
204
  # loss1 = caculate_loss_self_att(self_first, self_second, self_third, bboxes=bboxes,
@@ -213,9 +219,7 @@ class PLMSSampler(object):
213
  del att_first
214
  del att_second
215
  del att_third
216
- del self_first
217
- del self_second
218
- del self_third
219
 
220
  # grad_cond = x.grad
221
  x = x - grad_cond
 
1
+ # import os
2
+ # os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "0"
3
+ # os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
4
  import torch
5
  import numpy as np
6
  from tqdm import tqdm
 
192
 
193
  print("optimize", index1)
194
  self.model.train()
195
+ # torch.cuda.empty_cache()
196
  while loss.item() > loss_threshold and iteration < max_iter and (index1 < max_index) :
197
  print('iter', iteration)
198
  # import pdb; pdb.set_trace()
199
  x = x.requires_grad_(True)
200
  input['x'] = x
201
  e_t, att_first, att_second, att_third, self_first, self_second, self_third = self.model(input)
202
+
203
+ del self_first
204
+ del self_second
205
+ del self_third
206
+ del e_t
207
+
208
  bboxes = input['boxes_att']
209
  object_positions = input['object_position']
210
  # loss1 = caculate_loss_self_att(self_first, self_second, self_third, bboxes=bboxes,
 
219
  del att_first
220
  del att_second
221
  del att_third
222
+
 
 
223
 
224
  # grad_cond = x.grad
225
  x = x - grad_cond