Spaces:
Runtime error
Runtime error
Update autoregressive/models/gpt_t2i.py
Browse files
autoregressive/models/gpt_t2i.py
CHANGED
|
@@ -430,7 +430,7 @@ class Transformer(nn.Module):
|
|
| 430 |
token_embeddings = self.cls_embedding(cond_idx, train=self.training)
|
| 431 |
token_embeddings = token_embeddings[:,:self.cls_token_num]
|
| 432 |
if condition is not None:
|
| 433 |
-
condition_embeddings = self.condition_mlp(condition.to(torch.bfloat16),train=self.training)
|
| 434 |
self.condition_token = condition_embeddings
|
| 435 |
|
| 436 |
else: # decode_n_tokens(kv cache) in inference
|
|
|
|
| 430 |
token_embeddings = self.cls_embedding(cond_idx, train=self.training)
|
| 431 |
token_embeddings = token_embeddings[:,:self.cls_token_num]
|
| 432 |
if condition is not None:
|
| 433 |
+
condition_embeddings = self.condition_mlp(condition,train=self.training)#.to(torch.bfloat16),train=self.training)
|
| 434 |
self.condition_token = condition_embeddings
|
| 435 |
|
| 436 |
else: # decode_n_tokens(kv cache) in inference
|