Spaces:
Runtime error
Runtime error
bugfix model.py
Browse files — OmniGen/model.py (+3 −3)
OmniGen/model.py
CHANGED
|
@@ -347,7 +347,7 @@ class OmniGen(nn.Module, PeftAdapterMixin):
|
|
| 347 |
x = self.final_layer(image_embedding, time_emb)
|
| 348 |
latents = self.unpatchify(x, shapes[0], shapes[1])
|
| 349 |
|
| 350 |
- if
|
| 351 |
return latents, past_key_values
|
| 352 |
return latents
|
| 353 |
|
|
@@ -357,7 +357,7 @@ class OmniGen(nn.Module, PeftAdapterMixin):
|
|
| 357 |
Forward pass of DiT, but also batches the unconditional forward pass for classifier-free guidance.
|
| 358 |
"""
|
| 359 |
self.llm.config.use_cache = use_kv_cache
|
| 360 |
- model_out, past_key_values = self.forward(x, timestep, input_ids, input_img_latents, input_image_sizes, attention_mask, position_ids, past_key_values=past_key_values)
|
| 361 |
if use_img_cfg:
|
| 362 |
cond, uncond, img_cond = torch.split(model_out, len(model_out) // 3, dim=0)
|
| 363 |
cond = uncond + img_cfg_scale * (img_cond - uncond) + cfg_scale * (cond - img_cond)
|
|
@@ -371,7 +371,7 @@ class OmniGen(nn.Module, PeftAdapterMixin):
|
|
| 371 |
|
| 372 |
|
| 373 |
@torch.no_grad()
|
| 374 |
- def forward_with_separate_cfg(self, x, timestep, input_ids, input_img_latents, input_image_sizes, attention_mask, position_ids, cfg_scale, use_img_cfg, img_cfg_scale, past_key_values, use_kv_cache):
|
| 375 |
"""
|
| 376 |
Forward pass of DiT, but also batches the unconditional forward pass for classifier-free guidance.
|
| 377 |
"""
|
|
|
|
| 347 |
x = self.final_layer(image_embedding, time_emb)
|
| 348 |
latents = self.unpatchify(x, shapes[0], shapes[1])
|
| 349 |
|
| 350 |
+ if return_past_key_values:
|
| 351 |
return latents, past_key_values
|
| 352 |
return latents
|
| 353 |
|
|
|
|
| 357 |
Forward pass of DiT, but also batches the unconditional forward pass for classifier-free guidance.
|
| 358 |
"""
|
| 359 |
self.llm.config.use_cache = use_kv_cache
|
| 360 |
+ model_out, past_key_values = self.forward(x, timestep, input_ids, input_img_latents, input_image_sizes, attention_mask, position_ids, past_key_values=past_key_values, return_past_key_values=True)
|
| 361 |
if use_img_cfg:
|
| 362 |
cond, uncond, img_cond = torch.split(model_out, len(model_out) // 3, dim=0)
|
| 363 |
cond = uncond + img_cfg_scale * (img_cond - uncond) + cfg_scale * (cond - img_cond)
|
|
|
|
| 371 |
|
| 372 |
|
| 373 |
@torch.no_grad()
|
| 374 |
+ def forward_with_separate_cfg(self, x, timestep, input_ids, input_img_latents, input_image_sizes, attention_mask, position_ids, cfg_scale, use_img_cfg, img_cfg_scale, past_key_values, use_kv_cache, return_past_key_values=True):
|
| 375 |
"""
|
| 376 |
Forward pass of DiT, but also batches the unconditional forward pass for classifier-free guidance.
|
| 377 |
"""
|