Pusheen committed
Commit a8019fa · verified · parent 8ecf30b

Update gligen/ldm/models/diffusion/plms.py

Files changed (1):
  1. gligen/ldm/models/diffusion/plms.py  +68 −27
gligen/ldm/models/diffusion/plms.py CHANGED
@@ -3,9 +3,10 @@ import numpy as np
 from tqdm import tqdm
 from functools import partial
 from copy import deepcopy
+from diffusers import AutoencoderKL, LMSDiscreteScheduler
 from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like
 import math
-from ldm.models.diffusion.loss import caculate_loss_att_fixed_cnt, caculate_loss_self_att
+from ldm.models.diffusion.loss import caculate_loss_att_fixed_cnt, caculate_loss_self_att, caculate_loss_LoCo_V2
 class PLMSSampler(object):
     def __init__(self, diffusion, model, schedule="linear", alpha_generator_func=None, set_alpha_scale=None):
         super().__init__()
@@ -57,14 +58,14 @@ class PLMSSampler(object):
 
 
     # @torch.no_grad()
-    def sample(self, S, shape, input, uc=None, guidance_scale=1, mask=None, x0=None, loss_type='SAR_CAR'):
+    def sample(self, S, shape, input, uc=None, guidance_scale=1, mask=None, x0=None, loss_type=None):
         self.make_schedule(ddim_num_steps=S)
         # import pdb; pdb.set_trace()
         return self.plms_sampling(shape, input, uc, guidance_scale, mask=mask, x0=x0, loss_type=loss_type)
 
 
     # @torch.no_grad()
-    def plms_sampling(self, shape, input, uc=None, guidance_scale=1, mask=None, x0=None, loss_type='SAR_CAR'):
+    def plms_sampling(self, shape, input, uc=None, guidance_scale=1, mask=None, x0=None, loss_type=None):
 
         b = shape[0]
 
@@ -81,6 +82,7 @@ class PLMSSampler(object):
         if self.alpha_generator_func != None:
             alphas = self.alpha_generator_func(len(time_range))
 
+
         for i, step in enumerate(time_range):
 
             # set alpha and restore first conv layer
@@ -102,12 +104,7 @@ class PLMSSampler(object):
             # three loss types
             if loss_type !=None and loss_type!='standard':
                 if input['object_position'] != []:
-                    if loss_type=='SAR_CAR':
-                        x = self.update_loss_self_cross( input,i, index, ts )
-                    elif loss_type=='SAR':
-                        x = self.update_only_self( input,i, index, ts )
-                    elif loss_type=='CAR':
-                        x = self.update_loss_only_cross( input,i, index, ts )
+                    x = self.update_loss_LoCo( input,i, index, ts, time_factor = time_factor)
             input["x"] = x
             img, pred_x0, e_t = self.p_sample_plms(input, ts, index=index, uc=uc, guidance_scale=guidance_scale, old_eps=old_eps, t_next=ts_next)
             input["x"] = img
@@ -119,11 +116,11 @@ class PLMSSampler(object):
 
     def update_loss_self_cross(self, input,index1, index, ts,type_loss='self_accross' ):
         if index1 < 10:
-            loss_scale = 4
-            max_iter = 1
-        elif index1 < 20:
             loss_scale = 3
-            max_iter = 1
+            max_iter = 5
+        elif index1 < 20:
+            loss_scale = 2
+            max_iter = 3
         else:
             loss_scale = 1
             max_iter = 1
@@ -136,29 +133,25 @@ class PLMSSampler(object):
         input["timesteps"] = ts
 
         print("optimize", index1)
-        self.model.train()
         while loss.item() > loss_threshold and iteration < max_iter and (index1 < max_index) :
             print('iter', iteration)
-            # import pdb; pdb.set_trace()
             x = x.requires_grad_(True)
             input['x'] = x
             e_t, att_first, att_second, att_third, self_first, self_second, self_third = self.model(input)
-            bboxes = input['boxes_att']
+            bboxes = input['boxes']
             object_positions = input['object_position']
             loss1 = caculate_loss_self_att(self_first, self_second, self_third, bboxes=bboxes,
                                            object_positions=object_positions, t = index1)*loss_scale
             loss2 = caculate_loss_att_fixed_cnt(att_second,att_first,att_third, bboxes=bboxes,
                                            object_positions=object_positions, t = index1)*loss_scale
             loss = loss1 + loss2
-            print('loss', loss, loss1, loss2)
-            # hh = torch.autograd.backward(loss, retain_graph=True)
-            grad_cond = torch.autograd.grad(loss.requires_grad_(True), [x])[0]
-            # grad_cond = x.grad
+            print('AR loss:', loss, 'SAR:', loss1, 'CAR:', loss2)
+            hh = torch.autograd.backward(loss)
+            grad_cond = x.grad
             x = x - grad_cond
             x = x.detach()
             iteration += 1
-
-
+            torch.cuda.empty_cache()
         return x
 
     def update_loss_only_cross(self, input,index1, index, ts,type_loss='self_accross'):
@@ -184,6 +177,7 @@ class PLMSSampler(object):
         while loss.item() > loss_threshold and iteration < max_iter and (index1 < max_index) :
             print('iter', iteration)
             x = x.requires_grad_(True)
+            print('x shape', x.shape)
             input['x'] = x
             e_t, att_first, att_second, att_third, self_first, self_second, self_third = self.model(input)
 
@@ -193,7 +187,55 @@ class PLMSSampler(object):
                                            object_positions=object_positions, t = index1)*loss_scale
             loss = loss2
             print('loss', loss)
-            hh = torch.autograd.backward(loss)
+            hh = torch.autograd.backward(loss, retain_graph=True)
+            grad_cond = x.grad
+            x = x - grad_cond
+            x = x.detach()
+            iteration += 1
+            torch.cuda.empty_cache()
+        return x
+
+    def update_loss_LoCo(self, input,index1, index, ts, time_factor, type_loss='self_accross'):
+
+        # loss_scale = 30
+        # max_iter = 5
+        # print('time_factor is: ', time_factor)
+        if index1 < 10:
+            loss_scale = 8
+            max_iter = 5
+        elif index1 < 20:
+            loss_scale = 5
+            max_iter = 5
+        else:
+            loss_scale = 1
+            max_iter = 1
+        loss_threshold = 0.1
+
+        max_index = 30
+        x = deepcopy(input["x"])
+        iteration = 0
+        loss = torch.tensor(10000)
+        input["timesteps"] = ts
+
+        # print("optimize", index1)
+        while loss.item() > loss_threshold and iteration < max_iter and (index1 < max_index) :
+            # print('iter', iteration)
+            x = x.requires_grad_(True)
+            # print('x shape', x.shape)
+            input['x'] = x
+            e_t, att_first, att_second, att_third, self_first, self_second, self_third = self.model(input)
+
+            bboxes = input['boxes']
+            object_positions = input['object_position']
+            loss2 = caculate_loss_LoCo_V2(att_second,att_first,att_third, bboxes=bboxes,
+                                           object_positions=object_positions, t = index1)*loss_scale
+            # loss = loss2
+            # loss.requires_grad_(True)
+            # print('LoCo loss', loss)
+
+
+
+            hh = torch.autograd.backward(loss2, retain_graph=True)
             grad_cond = x.grad
             x = x - grad_cond
             x = x.detach()
@@ -244,13 +286,12 @@ class PLMSSampler(object):
     def p_sample_plms(self, input, t, index, guidance_scale=1., uc=None, old_eps=None, t_next=None):
         x = deepcopy(input["x"])
         b = x.shape[0]
-        self.model.eval()
+
         def get_model_output(input):
             e_t, first, second, third,_,_,_ = self.model(input)
             if uc is not None and guidance_scale != 1:
-                unconditional_input = dict(x=input["x"], timesteps=input["timesteps"], context=uc, inpainting_extra_input=None, grounding_extra_input=None)
-                # unconditional_input=input
-                e_t_uncond, _, _, _, _, _, _ = self.model( unconditional_input)
+                unconditional_input = dict(x=input["x"], timesteps=input["timesteps"], context=uc, inpainting_extra_input=input["inpainting_extra_input"], grounding_extra_input=input['grounding_extra_input'])
+                e_t_uncond, _, _, _, _, _, _ = self.model( unconditional_input )
                 e_t = e_t_uncond + guidance_scale * (e_t - e_t_uncond)
             return e_t
 
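The update_loss_* methods in this diff all follow one loss-guided sampling pattern: at early denoising steps, the latent is nudged by gradient descent on an attention-based localization loss before the PLMS update proper runs. Below is a minimal, self-contained sketch of that loop; loss_guided_update, model, and localization_loss are hypothetical stand-ins for self.model and the caculate_loss_* functions, and the dict keys mirror the ones used above.

import torch

def loss_guided_update(x, model, localization_loss, inp,
                       loss_scale=8, max_iter=5, loss_threshold=0.1):
    # Optimize the latent x so the model's attention maps respect the
    # grounding boxes, then hand the adjusted latent back to the sampler.
    loss = torch.tensor(1e4)  # sentinel so the loop body runs at least once
    iteration = 0
    while loss.item() > loss_threshold and iteration < max_iter:
        x = x.requires_grad_(True)   # leaf tensor we can differentiate w.r.t.
        inp["x"] = x
        _, attn_maps = model(inp)    # forward pass with autograd enabled
        loss = localization_loss(attn_maps,
                                 bboxes=inp["boxes"],
                                 object_positions=inp["object_position"]) * loss_scale
        loss.backward()              # gradient lands in x.grad
        x = (x - x.grad).detach()    # one plain gradient step, then cut the graph
        iteration += 1
        torch.cuda.empty_cache()     # mirrors the commit's memory housekeeping
    return x

Detaching x at the end of each iteration keeps successive passes from chaining into one long autograd graph; each backward call only ever sees the graph built by the current step's forward pass.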
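In get_model_output, the unconditional branch now forwards input["inpainting_extra_input"] and input['grounding_extra_input'] instead of None, so both classifier-free-guidance passes see the same extra conditioning. The blend on the last line is the standard classifier-free guidance formula; a small sketch with illustrative shapes (the helper name and values here are assumptions for illustration, not the repo's API):

import torch

def classifier_free_guidance(e_t_cond, e_t_uncond, guidance_scale):
    # guidance_scale == 1 returns the conditional prediction unchanged;
    # larger values extrapolate away from the unconditional prediction.
    return e_t_uncond + guidance_scale * (e_t_cond - e_t_uncond)

# Illustrative only: a batch of 4 latents, 4 channels, 64x64.
e_cond = torch.randn(4, 4, 64, 64)
e_uncond = torch.randn(4, 4, 64, 64)
e_t = classifier_free_guidance(e_cond, e_uncond, guidance_scale=7.5)
assert e_t.shape == e_cond.shape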
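Not shown in this diff, but relevant to the old_eps/t_next bookkeeping around p_sample_plms: PLMS replaces the single noise prediction with a linear-multistep (Adams-Bashforth-style) combination of the most recent predictions. A sketch of the standard coefficient schedule, as in the CompVis plms.py this file derives from (plms_eps is our name for the inline logic):

def plms_eps(e_t, old_eps):
    # Linear multistep combination of noise predictions; lower-order
    # formulas are used until three past predictions are stored.
    if len(old_eps) == 0:
        return e_t  # first step: the sampler instead uses a two-pass estimate
    if len(old_eps) == 1:
        return (3 * e_t - old_eps[-1]) / 2
    if len(old_eps) == 2:
        return (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
    return (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24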