Spaces:
Build error
Build error
| import torch | |
| import torch.nn as nn | |
class Film(nn.Module):
    """Additive conditioning layer: shifts each channel by a learned bias.

    A small MLP maps a conditioning vector to one bias value per channel.
    The bias is broadcast over every trailing (time / frequency / sample)
    axis of the input and added to it. Only a shift is applied; there is
    no multiplicative scale term. Because the MLP ends in a ReLU, the bias
    is always non-negative.
    """

    def __init__(self, channels: int, cond_embedding_dim: int) -> None:
        """
        :param channels: number of channels in the tensors passed to ``forward``
        :param cond_embedding_dim: size of the conditioning vector
        """
        super().__init__()
        # cond_embedding_dim -> 2 * channels -> channels, ReLU after each layer.
        self.linear = nn.Sequential(
            nn.Linear(cond_embedding_dim, channels * 2),
            nn.ReLU(inplace=True),
            nn.Linear(channels * 2, channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, data: torch.Tensor, cond_vec: torch.Tensor) -> torch.Tensor:
        """
        :param data: [batchsize, channels, samples] or [batchsize, channels, T, F] or [batchsize, channels, F, T]
        :param cond_vec: [batchsize, cond_embedding_dim]
        :return: ``data`` plus the per-channel bias, same shape as ``data``.
            If ``data`` is neither 3-D nor 4-D a warning is printed and
            ``data`` is returned unchanged.
        """
        bias = self.linear(cond_vec)  # [batchsize, channels]
        rank = data.dim()
        if rank == 3:
            # Append one singleton axis so the bias broadcasts over samples.
            return data + bias[..., None]
        if rank == 4:
            # Append two singleton axes so the bias broadcasts over both
            # trailing axes (T/F in either order).
            return data + bias[..., None, None]
        # Best-effort fallback kept from the original: warn and pass through.
        print("Warning: The size of input tensor,", data.size(), "is not correct. Film is not working.")
        return data