Kohaku-Blueleaf commited on
Commit
2bc3bdc
·
1 Parent(s): 517b379
Files changed (1) hide show
  1. app.py +2 -15
app.py CHANGED
@@ -129,22 +129,9 @@ def cfg_wrapper(
129
  emb = te(**prompt_token).last_hidden_state
130
  neg_emb = te(**neg_prompt_token).last_hidden_state
131
 
132
- if emb.size(1) > neg_emb.size(1):
133
- pad_setting = (0, 0, 0, emb.size(1) - neg_emb.size(1))
134
- neg_emb = F.pad(neg_emb, pad_setting)
135
- if neg_emb.size(1) > emb.size(1):
136
- pad_setting = (0, 0, 0, neg_emb.size(1) - emb.size(1))
137
- emb = F.pad(emb, pad_setting)
138
- text_ctx_emb = torch.concat([emb, neg_emb])
139
-
140
  def cfg_fn(x, t, cfg=cfg_scale):
141
- cond, uncond = unet(
142
- x.repeat(2, 1, 1, 1),
143
- t.expand(x.size(0) * 2),
144
- text_ctx_emb,
145
- ).chunk(2)
146
- cond = cond.float()
147
- uncond = uncond.float()
148
  return uncond + (cond - uncond) * cfg
149
 
150
  return cfg_fn
 
129
  emb = te(**prompt_token).last_hidden_state
130
  neg_emb = te(**neg_prompt_token).last_hidden_state
131
 
 
 
 
 
 
 
 
 
132
  def cfg_fn(x, t, cfg=cfg_scale):
133
+ cond = unet(x, t.expand(x.size(0)), emb).float()
134
+ uncond = unet(x, t.expand(x.size(0)), neg_emb).float()
 
 
 
 
 
135
  return uncond + (cond - uncond) * cfg
136
 
137
  return cfg_fn