File size: 1,614 Bytes
6da2a44 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 |
import torch
from timm.loss import SoftTargetCrossEntropy
from timm.models.layers import drop
from models.network.hymamba import Encoder
# Log more: make the loss / DropPath modules show their hyper-parameters when printed.
def _ex_repr(self):
    """Return a ``key=value, ...`` summary of the public attributes of *self*.

    Attributes are read from ``vars(self)``. Entries are skipped when the
    name starts with an underscore, when the name is ``training``, or when
    the value is a ``torch.nn.Module`` or ``torch.Tensor``. Floats are
    rendered with the ``g`` format; everything else goes through ``str``.
    """
    skipped_types = (torch.nn.Module, torch.Tensor)
    parts = []
    for key, value in vars(self).items():
        # Private attributes and the train/eval flag are not worth logging.
        if key.startswith('_') or key == 'training':
            continue
        # Sub-modules and tensors have their own (verbose) representations.
        if isinstance(value, skipped_types):
            continue
        shown = format(value, 'g') if isinstance(value, float) else str(value)
        parts.append(f'{key}={shown}')
    return ', '.join(parts)
# Patch the loss / DropPath classes so that printing them lists their settings.
for clz in (torch.nn.CrossEntropyLoss, SoftTargetCrossEntropy, drop.DropPath):
    if not hasattr(clz, 'extra_repr'):
        # No extra_repr hook on this class: replace __repr__ wholesale instead.
        clz.__repr__ = lambda self: f'{type(self).__name__}({_ex_repr(self)})'
        continue
    # The class exposes an extra_repr hook: route it through _ex_repr.
    clz.extra_repr = _ex_repr
# Default keyword arguments per pre-training model name.
pretrain_default_model_kwargs = {
    'mambamim': {'sparse': True, 'drop_path_rate': 0.1},
}
# Every entry additionally gets these three settings.
for kw in pretrain_default_model_kwargs.values():
    kw.update(pretrained=False, num_classes=0, global_pool='')
def build_sparse_encoder(name: str, input_size: int, sbn=False, drop_path_rate=0.0, verbose=False):
    """Build the ``Encoder`` backbone and wrap it in a ``SparseEncoder``.

    Args:
        name: Key into ``pretrain_default_model_kwargs``.
        input_size: Forwarded to ``SparseEncoder`` as ``input_size``.
        sbn: Forwarded to ``SparseEncoder`` as ``sbn``.
        drop_path_rate: When non-zero, replaces the default
            ``drop_path_rate`` in the logged kwargs for this call.
        verbose: Forwarded to ``SparseEncoder`` as ``verbose``.

    Returns:
        The ``SparseEncoder`` wrapping the newly created ``Encoder``.

    Raises:
        KeyError: If *name* is not a key of ``pretrain_default_model_kwargs``.
    """
    from models.encoder import SparseEncoder

    # Work on a copy: writing the override into the module-level defaults
    # would leak one call's drop_path_rate into every later call.
    kwargs = dict(pretrain_default_model_kwargs[name])
    if drop_path_rate != 0:
        kwargs['drop_path_rate'] = drop_path_rate
    print(f'[build_sparse_encoder] model kwargs={kwargs}')

    # NOTE(review): `kwargs` is only logged; the Encoder below is built from
    # fixed hyper-parameters and does not receive it (nor `input_size`) --
    # confirm this is intentional.
    encoder = Encoder(
        in_channel=1,
        channels=(32, 64, 128, 192, 384),
        depths=(1, 2, 2, 2, 1),
        kernels=(3, 3, 3, 3, 3),
        exp_r=(2, 2, 4, 4, 4),
        img_size=96,
        depth=4,
        sparse=True,
    )
    return SparseEncoder(encoder=encoder, input_size=input_size, sbn=sbn, verbose=verbose)