Update transformer.py
Browse files- transformer.py +2 -2
transformer.py
CHANGED
@@ -138,6 +138,6 @@ class ViT(nn.Module):
|
|
138 |
|
139 |
x = self.encoder(x)
|
140 |
|
141 |
-
|
142 |
|
143 |
-
return
|
|
|
138 |
|
139 |
x = self.encoder(x)
|
140 |
|
141 |
+
output = self.output(x, return_z=return_z)
|
142 |
|
143 |
+
return output
|