580136d b314508
1
2
3
4
5
6
7
import torch
import numpy as np


def pad_to_22_channels(input_tensor: torch.Tensor) -> torch.Tensor:
    """Expand a 3-channel tensor to 22 channels by cycling its channels.

    Channels are read from dim 1. For a 3-channel input the output channel
    order is ``0, 1, 2`` repeated seven times (21 channels) followed by
    channel ``0`` once more, for 22 channels in total.

    NOTE(review): dim 1 is assumed to be the channel axis (e.g. an NCHW
    batch) -- confirm against callers; an unbatched ``(C, H, W)`` tensor
    would be checked/tiled along H instead.

    Args:
        input_tensor: Tensor with at least two dimensions.

    Returns:
        A new tensor with 22 entries along dim 1 when ``input_tensor`` has
        exactly 3 there; otherwise ``input_tensor`` itself, unchanged (other
        channel counts are not padded).

    Raises:
        IndexError: If ``input_tensor`` has fewer than two dimensions.
    """
    target_channels = 22
    if input_tensor.shape[1] != 3:  # only RGB-shaped input is expanded
        return input_tensor
    # 22 == 3 * 7 + 1: seven full copies, then the leading remainder channel.
    full_copies, remainder = divmod(target_channels, input_tensor.shape[1])
    return torch.cat(
        [input_tensor] * full_copies + [input_tensor[:, :remainder]], dim=1
    )