|
import torch.nn as nn |
|
|
|
|
|
def change_num_input_channels(model, in_channels=1):
    """Adapt the first 3-input-channel convolution of ``model`` to ``in_channels`` inputs.

    Walks ``model.modules()`` in order and modifies, in place, the first
    ``nn.Conv2d`` / ``nn.Conv3d`` whose ``in_channels`` is 3. This assumes
    the model's input layer is that 3-channel convolution. Only the first
    match is changed. If nothing matches, the model is returned unchanged.

    The layer's weight is summed over the input-channel axis and divided by
    ``in_channels``. It is then repeated ``in_channels`` times along that axis.
    As a result, feeding ``in_channels`` identical copies of an image ``x``
    gives the same output as feeding ``(x, x, x)`` to the original layer.
    The bias is left untouched.

    Args:
        model: ``nn.Module`` to modify in place.
        in_channels: new number of input channels; must be >= 1.

    Returns:
        The same ``model`` object.

    Raises:
        ValueError: if ``in_channels`` is less than 1.
    """
    if in_channels < 1:
        raise ValueError(f"in_channels must be a positive integer, got {in_channels}")

    for m in model.modules():
        if isinstance(m, (nn.Conv2d, nn.Conv3d)) and m.in_channels == 3:
            m.in_channels = in_channels

            # Collapse the 3 input-channel filters into one and spread it
            # evenly over the new channels. Detached: this is one-off weight
            # surgery, so no autograd graph is needed.
            weight = m.weight.detach().sum(1, keepdim=True) / in_channels

            # Tile along the input-channel axis (dim 1) only.
            repeats = [1] * weight.ndim
            repeats[1] = in_channels
            weight = weight.repeat(repeats)

            # Preserve trainability: a frozen layer must stay frozen.
            m.weight = nn.Parameter(weight, requires_grad=m.weight.requires_grad)
            break

    return model
|
|
|
|
|
def change_initial_stride(model, stride, in_channels=3):
    """Set the stride of the first convolution in ``model`` with ``in_channels`` inputs.

    Walks ``model.modules()`` in order and modifies, in place, the first
    ``nn.Conv2d`` / ``nn.Conv3d`` whose ``in_channels`` equals
    ``in_channels``. Later matching layers are left untouched. If nothing
    matches, the model is returned unchanged.

    Args:
        model: ``nn.Module`` to modify in place.
        stride: new stride. Either an int, applied to every spatial
            dimension, or a sequence with one entry per spatial dimension.
        in_channels: input-channel count that identifies the target layer.
            Defaults to 3.

    Returns:
        The same ``model`` object.

    Raises:
        ValueError: if a sequence ``stride`` has the wrong length for the
            layer, or if the layer uses ``padding='same'`` and the new
            stride is not all ones.
    """
    for m in model.modules():
        if isinstance(m, (nn.Conv2d, nn.Conv3d)) and m.in_channels == in_channels:
            spatial_dims = m.weight.ndim - 2

            # nn.Conv2d/Conv3d keep ``stride`` as a per-dimension tuple.
            # Preserve that invariant so code reading ``m.stride[i]`` still works.
            if isinstance(stride, int):
                new_stride = (stride,) * spatial_dims
            else:
                new_stride = tuple(stride)
                if len(new_stride) != spatial_dims:
                    raise ValueError(
                        f"stride must have {spatial_dims} entries for {type(m).__name__}, "
                        f"got {new_stride}"
                    )

            # Mirror the check PyTorch performs at construction time.
            if m.padding == "same" and any(s != 1 for s in new_stride):
                raise ValueError("padding='same' is not supported for strided convolutions")

            m.stride = new_stride
            break

    return model