| import flax.linen as nn | |
| from .convnext import ConvNeXt | |
| from .swin_ir import SwinIR | |
def build_tail(size: str):
    """Build one of the three tail variants described in the paper.

    Args:
        size: Name of the tail variant; one of ``'air'``, ``'plus'`` or
            ``'pro'``.

    Returns:
        * ``'air'``  -- a plain two-argument callable that returns its first
          argument unchanged and ignores the second (no module is created).
        * ``'plus'`` -- a ``ConvNeXt`` built from 16 block specs:
          6 x ``(64, 3, True)``, 7 x ``(96, 3, True)``, 3 x ``(128, 3, True)``.
          NOTE(review): the tuple fields are presumably
          (features, kernel size, flag) -- confirm against ``ConvNeXt``.
        * ``'pro'``  -- a ``SwinIR`` with ``depths=[7, 6]`` and
          ``num_heads=[6, 6]``.

    Raises:
        NotImplementedError: If ``size`` is not one of the names above.
    """
    if size == 'air':
        # Pass-through tail: same two-argument call shape, input returned as-is.
        return lambda x, _: x

    if size == 'plus':
        blocks = [(64, 3, True)] * 6 + [(96, 3, True)] * 7 + [(128, 3, True)] * 3
        return ConvNeXt(blocks)

    if size == 'pro':
        return SwinIR(depths=[7, 6], num_heads=[6, 6])

    # Format instead of concatenating: `'size: ' + size` raises TypeError for
    # non-str arguments (e.g. None), hiding the intended error. The message is
    # unchanged for str inputs.
    raise NotImplementedError(f'size: {size}')