Commit · cae6d82
1 Parent(s): acd9841
app.py CHANGED
@@ -111,7 +111,7 @@ def encode_sdxl_prompt(prompt, negative_prompt=""):
     clip_l_embeds = pipe.text_encoder(tokens_l)[0]
     neg_clip_l_embeds = pipe.text_encoder(neg_tokens_l)[0]

-    # CLIP-G embeddings (1280d)
+    # CLIP-G embeddings (1280d) - get the hidden states [0], not pooled [1]
     clip_g_embeds = pipe.text_encoder_2(tokens_g)[0]
     neg_clip_g_embeds = pipe.text_encoder_2(neg_tokens_g)[0]
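Note on the amended comment: which positional output holds the per-token states depends on the encoder class. In Transformers, CLIPTextModel (SDXL's first text encoder) returns last_hidden_state at index [0] and the pooled output at [1], while CLIPTextModelWithProjection (the usual second encoder) returns the projected text_embeds at [0]. A quick way to verify against the pipeline's actual models, reusing tokens_l and tokens_g from the surrounding code (a sketch, assuming the stock encoders SDXL pipelines ship with):

# Sketch: print the named outputs so the positional indices are unambiguous.
out_l = pipe.text_encoder(tokens_l, output_hidden_states=True)
out_g = pipe.text_encoder_2(tokens_g, output_hidden_states=True)
print(list(out_l.keys()))  # e.g. ['last_hidden_state', 'pooler_output', 'hidden_states']
print(list(out_g.keys()))  # e.g. ['text_embeds', 'last_hidden_state', 'hidden_states']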
@@ -143,14 +143,7 @@ def infer(prompt, negative_prompt, adapter_l_file, adapter_g_file, strength, noi
     pipe.scheduler = SCHEDULERS[scheduler_name].from_config(pipe.scheduler.config)

     # Get T5 embeddings for semantic understanding
-    t5_ids = t5_tok(
-        prompt,
-        return_tensors="pt",
-        padding="max_length",
-        max_length=77,  # Match CLIP's standard length
-        truncation=True
-    ).input_ids.to(device)
-    print(t5_ids.shape)
+    t5_ids = t5_tok(prompt, return_tensors="pt", padding=True, truncation=True).input_ids.to(device)
     t5_seq = t5_mod(t5_ids).last_hidden_state

     # Get proper SDXL CLIP embeddings
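The one-liner above is not just a formatting change: padding="max_length" with max_length=77 always produced a 77-token sequence, whereas padding=True only pads to the longest prompt in the batch, so t5_seq can now be shorter than CLIP's 77 tokens. A standalone sketch of the difference, using google/flan-t5-base as a stand-in for the Space's actual T5 checkpoint:

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("google/flan-t5-base")  # stand-in checkpoint

fixed = tok("a cat", return_tensors="pt",
            padding="max_length", max_length=77, truncation=True)
dynamic = tok("a cat", return_tensors="pt", padding=True, truncation=True)

print(fixed.input_ids.shape)    # torch.Size([1, 77]) - always CLIP-length
print(dynamic.input_ids.shape)  # e.g. torch.Size([1, 3]) - only as long as the prompt

This length mismatch is what the resize block added in the next hunk compensates for.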
@@ -160,6 +153,19 @@ def infer(prompt, negative_prompt, adapter_l_file, adapter_g_file, strength, noi
     adapter_l = load_adapter(repo_l, adapter_l_file, config_l) if adapter_l_file else None
     adapter_g = load_adapter(repo_g, adapter_g_file, config_g) if adapter_g_file else None

+    # Ensure all embeddings have the same sequence length (77 tokens)
+    seq_len = 77
+
+    # Resize T5 to match CLIP sequence length
+    if t5_seq.size(1) != seq_len:
+        t5_seq = torch.nn.functional.interpolate(
+            t5_seq.transpose(1, 2),
+            size=seq_len,
+            mode="nearest"
+        ).transpose(1, 2)
+
+    print(f"After resize - T5: {t5_seq.shape}, CLIP-L: {clip_embeds['clip_l'].shape}, CLIP-G: {clip_embeds['clip_g'].shape}")
+
     # Apply CLIP-L adapter
     if adapter_l is not None:
         anchor_l, delta_l, log_sigma_l, attn_l1, attn_l2, tau_l, g_pred_l, gate_l = adapter_l(t5_seq, clip_embeds["clip_l"])
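The resize works because torch.nn.functional.interpolate treats a 3-D tensor as (batch, channels, length): the token axis has to be swapped into the last slot, stretched to 77, and swapped back. A self-contained sketch with illustrative shapes:

import torch
import torch.nn.functional as F

t5_seq = torch.randn(1, 12, 768)                  # (batch, tokens, dim): a 12-token prompt

resized = F.interpolate(t5_seq.transpose(1, 2),   # -> (1, 768, 12)
                        size=77, mode="nearest")  # -> (1, 768, 77)
resized = resized.transpose(1, 2)                 # -> (1, 77, 768)

print(resized.shape)  # torch.Size([1, 77, 768])

mode="nearest" duplicates existing token vectors rather than blending neighbours, so individual embeddings stay intact at the cost of an uneven stretch.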
@@ -187,6 +193,23 @@ def infer(prompt, negative_prompt, adapter_l_file, adapter_g_file, strength, noi
             clip_g_mod = clip_g_mod * (1 - gate_g_scaled) + anchor_g * gate_g_scaled
         if noise > 0:
             clip_g_mod += torch.randn_like(clip_g_mod) * noise
+        else:
+            clip_g_mod = clip_embeds["clip_g"]
+            delta_g_final = torch.zeros_like(clip_embeds["clip_g"])
+            gate_g_scaled = torch.zeros_like(clip_embeds["clip_g"])
+            g_pred_g = torch.tensor(0.0)
+            tau_g = torch.tensor(0.0)
+    else:
+        t5_seq_resized = t5_seq
+
+        anchor_g, delta_g, log_sigma_g, attn_g1, attn_g2, tau_g, g_pred_g, gate_g = adapter_g(t5_seq_resized, clip_embeds["clip_g"])
+        gate_g_scaled = gate_g * gate_prob
+        delta_g_final = delta_g * strength * gate_g_scaled
+        clip_g_mod = clip_embeds["clip_g"] + delta_g_final
+        if use_anchor:
+            clip_g_mod = clip_g_mod * (1 - gate_g_scaled) + anchor_g * gate_g_scaled
+        if noise > 0:
+            clip_g_mod += torch.randn_like(clip_g_mod) * noise
     else:
         clip_g_mod = clip_embeds["clip_g"]
         delta_g_final = torch.zeros_like(clip_embeds["clip_g"])
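For reference, the block added above applies the same gated-residual update to CLIP-G that the earlier code applies to CLIP-L: the adapter's delta is scaled by the strength slider and a per-token gate, added to the base embedding, then optionally blended toward the adapter's anchor. A minimal sketch with assumed shapes (the real gate shape depends on the adapter implementation):

import torch

clip_g   = torch.randn(1, 77, 1280)  # base SDXL CLIP-G embeddings
delta_g  = torch.randn(1, 77, 1280)  # adapter-predicted residual
anchor_g = torch.randn(1, 77, 1280)  # adapter-predicted anchor
gate_g   = torch.rand(1, 77, 1)      # assumed per-token gate in [0, 1]
strength, gate_prob, use_anchor, noise = 1.0, 0.5, True, 0.0

gate_g_scaled = gate_g * gate_prob   # damp the gate globally
clip_g_mod = clip_g + delta_g * strength * gate_g_scaled
if use_anchor:  # lerp toward the anchor by the gate amount
    clip_g_mod = clip_g_mod * (1 - gate_g_scaled) + anchor_g * gate_g_scaled
if noise > 0:   # optional exploration noise
    clip_g_mod += torch.randn_like(clip_g_mod) * noise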
|