Kohaku-Blueleaf
commited on
Commit
·
2bc3bdc
1
Parent(s):
517b379
dont pad
Browse files
app.py
CHANGED
@@ -129,22 +129,9 @@ def cfg_wrapper(
|
|
129 |
emb = te(**prompt_token).last_hidden_state
|
130 |
neg_emb = te(**neg_prompt_token).last_hidden_state
|
131 |
|
132 |
-
if emb.size(1) > neg_emb.size(1):
|
133 |
-
pad_setting = (0, 0, 0, emb.size(1) - neg_emb.size(1))
|
134 |
-
neg_emb = F.pad(neg_emb, pad_setting)
|
135 |
-
if neg_emb.size(1) > emb.size(1):
|
136 |
-
pad_setting = (0, 0, 0, neg_emb.size(1) - emb.size(1))
|
137 |
-
emb = F.pad(emb, pad_setting)
|
138 |
-
text_ctx_emb = torch.concat([emb, neg_emb])
|
139 |
-
|
140 |
def cfg_fn(x, t, cfg=cfg_scale):
|
141 |
-
cond
|
142 |
-
|
143 |
-
t.expand(x.size(0) * 2),
|
144 |
-
text_ctx_emb,
|
145 |
-
).chunk(2)
|
146 |
-
cond = cond.float()
|
147 |
-
uncond = uncond.float()
|
148 |
return uncond + (cond - uncond) * cfg
|
149 |
|
150 |
return cfg_fn
|
|
|
129 |
emb = te(**prompt_token).last_hidden_state
|
130 |
neg_emb = te(**neg_prompt_token).last_hidden_state
|
131 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
132 |
def cfg_fn(x, t, cfg=cfg_scale):
|
133 |
+
cond = unet(x, t.expand(x.size(0)), emb).float()
|
134 |
+
uncond = unet(x, t.expand(x.size(0)), neg_emb).float()
|
|
|
|
|
|
|
|
|
|
|
135 |
return uncond + (cond - uncond) * cfg
|
136 |
|
137 |
return cfg_fn
|