Spaces:
Running
Running
VAE: Support returning intermediate features for 3d perceptual loss
Browse files
xora/models/autoencoders/video_autoencoder.py
CHANGED
|
@@ -310,7 +310,9 @@ class Encoder(nn.Module):
|
|
| 310 |
* self.patch_size
|
| 311 |
)
|
| 312 |
|
| 313 |
-
def forward(
|
|
|
|
|
|
|
| 314 |
r"""The forward method of the `Encoder` class."""
|
| 315 |
|
| 316 |
downsample_in_time = sample.shape[2] != 1
|
|
@@ -332,10 +334,14 @@ class Encoder(nn.Module):
|
|
| 332 |
else lambda x: x
|
| 333 |
)
|
| 334 |
|
|
|
|
|
|
|
| 335 |
for down_block in self.down_blocks:
|
| 336 |
sample = checkpoint_fn(down_block)(
|
| 337 |
sample, downsample_in_time=downsample_in_time
|
| 338 |
)
|
|
|
|
|
|
|
| 339 |
|
| 340 |
sample = checkpoint_fn(self.mid_block)(sample)
|
| 341 |
|
|
@@ -363,6 +369,11 @@ class Encoder(nn.Module):
|
|
| 363 |
else:
|
| 364 |
raise ValueError(f"Invalid input shape: {sample.shape}")
|
| 365 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 366 |
return sample
|
| 367 |
|
| 368 |
|
|
|
|
| 310 |
* self.patch_size
|
| 311 |
)
|
| 312 |
|
| 313 |
+
def forward(
|
| 314 |
+
self, sample: torch.FloatTensor, return_features=False
|
| 315 |
+
) -> torch.FloatTensor:
|
| 316 |
r"""The forward method of the `Encoder` class."""
|
| 317 |
|
| 318 |
downsample_in_time = sample.shape[2] != 1
|
|
|
|
| 334 |
else lambda x: x
|
| 335 |
)
|
| 336 |
|
| 337 |
+
if return_features:
|
| 338 |
+
features = []
|
| 339 |
for down_block in self.down_blocks:
|
| 340 |
sample = checkpoint_fn(down_block)(
|
| 341 |
sample, downsample_in_time=downsample_in_time
|
| 342 |
)
|
| 343 |
+
if return_features:
|
| 344 |
+
features.append(sample)
|
| 345 |
|
| 346 |
sample = checkpoint_fn(self.mid_block)(sample)
|
| 347 |
|
|
|
|
| 369 |
else:
|
| 370 |
raise ValueError(f"Invalid input shape: {sample.shape}")
|
| 371 |
|
| 372 |
+
if return_features:
|
| 373 |
+
features.append(
|
| 374 |
+
sample[:, sample.shape[1] // 2, ...]
|
| 375 |
+
) # Add the latent means as final feature
|
| 376 |
+
return sample, features
|
| 377 |
return sample
|
| 378 |
|
| 379 |
|