Spaces:
Sleeping
Sleeping
import torch | |
import numpy as np | |
def pad_to_22_channels(input_tensor, *, target_channels=22):
    """Expand a 3-channel (RGB) tensor to ``target_channels`` channels by tiling.

    Output channel ``i`` is input channel ``i % 3``. With the default of 22
    this gives seven full R,G,B repeats followed by one extra copy of
    channel 0. This is the same layout as
    ``torch.cat([x] * 7 + [x[:, 0:1]], dim=1)``.

    Inputs whose dimension 1 is not exactly 3 are returned unchanged (the
    same object, not a copy). This includes tensors that already have 22
    channels.

    Args:
        input_tensor: Tensor with at least 2 dimensions. Channels are read
            from dimension 1.
        target_channels: Number of output channels produced for 3-channel
            input. Keyword-only; defaults to 22.

    Returns:
        A new tensor with ``target_channels`` entries on dimension 1 when the
        input has 3 channels; otherwise ``input_tensor`` itself.
    """
    if input_tensor.shape[1] != 3:
        # Not RGB: leave untouched, matching the original pass-through.
        return input_tensor
    # A single gather along dim 1 replaces concatenating eight slices.
    # The index is built on the input's device so index_select accepts it.
    channel_index = torch.arange(target_channels, device=input_tensor.device) % 3
    return input_tensor.index_select(1, channel_index)