cantabile-kwok
prepare demo page
05005db
raw
history blame contribute delete
623 Bytes
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
Transpose a chosen dimension of the input (by default the second-to-last)
with its last dimension.
"""
import torch.nn as nn
class TransposeLast(nn.Module):
    """Swap one dimension of the input with its last dimension.

    Args:
        deconstruct_idx: When not ``None``, the value passed to ``forward``
            is first indexed with it (``x[deconstruct_idx]``) and the result
            of that indexing is what gets transposed.
        tranpose_dim: Dimension that is swapped with the last dimension
            (``-1``). Defaults to ``-2``. The spelling is kept unchanged
            because it is part of the keyword interface.
    """

    def __init__(self, deconstruct_idx=None, tranpose_dim=-2):
        super().__init__()
        self.deconstruct_idx = deconstruct_idx
        self.tranpose_dim = tranpose_dim

    def forward(self, x):
        """Return the (optionally indexed) input with ``tranpose_dim`` and ``-1`` swapped."""
        # Compare against None explicitly so an index of 0 is still applied.
        if self.deconstruct_idx is not None:
            x = x[self.deconstruct_idx]
        return x.transpose(self.tranpose_dim, -1)