Spaces:
Runtime error
Runtime error
Update autoregressive/models/generate.py
Browse files
autoregressive/models/generate.py
CHANGED
|
@@ -140,10 +140,12 @@ def generate(model, cond, max_new_tokens, emb_masks=None, cfg_scale=1.0, cfg_int
|
|
| 140 |
condition = condition.to(torch.float32)
|
| 141 |
print(condition)
|
| 142 |
if condition is not None:
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
|
|
|
|
|
|
|
| 147 |
if model.model_type == 'c2i':
|
| 148 |
if cfg_scale > 1.0:
|
| 149 |
cond_null = torch.ones_like(cond) * model.num_classes
|
|
|
|
| 140 |
condition = condition.to(torch.float32)
|
| 141 |
print(condition)
|
| 142 |
if condition is not None:
|
| 143 |
+
with torch.no_grad():
|
| 144 |
+
print(model.adapter.model.embeddings.patch_embeddings.projection.weight)
|
| 145 |
+
condition = model.adapter(condition)
|
| 146 |
+
print(condition)
|
| 147 |
+
condition = model.adapter_mlp(condition)
|
| 148 |
+
print(condition)
|
| 149 |
if model.model_type == 'c2i':
|
| 150 |
if cfg_scale > 1.0:
|
| 151 |
cond_null = torch.ones_like(cond) * model.num_classes
|