Spaces:
Running
Running
Update model/SUNet_detail.py
Browse files- model/SUNet_detail.py +1 -24
model/SUNet_detail.py
CHANGED
|
@@ -3,7 +3,7 @@ import torch.nn as nn
|
|
| 3 |
import torch.utils.checkpoint as checkpoint
|
| 4 |
from einops import rearrange
|
| 5 |
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
|
| 6 |
-
|
| 7 |
|
| 8 |
class Mlp(nn.Module):
|
| 9 |
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
|
|
@@ -763,26 +763,3 @@ class SUNet(nn.Module):
|
|
| 763 |
flops += self.num_features * self.out_chans
|
| 764 |
return flops
|
| 765 |
|
| 766 |
-
|
| 767 |
-
if __name__ == '__main__':
|
| 768 |
-
from utils.model_utils import network_parameters
|
| 769 |
-
|
| 770 |
-
height = 256
|
| 771 |
-
width = 256
|
| 772 |
-
x = torch.randn((1, 3, height, width)) # .cuda()
|
| 773 |
-
model = SUNet(img_size=256, patch_size=4, in_chans=3, out_chans=3,
|
| 774 |
-
embed_dim=96, depths=[8, 8, 8, 8],
|
| 775 |
-
num_heads=[8, 8, 8, 8],
|
| 776 |
-
window_size=8, mlp_ratio=4., qkv_bias=True, qk_scale=2,
|
| 777 |
-
drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
|
| 778 |
-
norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
|
| 779 |
-
use_checkpoint=False, final_upsample="Dual up-sample") # .cuda()
|
| 780 |
-
# print(model)
|
| 781 |
-
print('input image size: (%d, %d)' % (height, width))
|
| 782 |
-
print('FLOPs: %.4f G' % (model.flops() / 1e9))
|
| 783 |
-
print('model parameters: ', network_parameters(model))
|
| 784 |
-
# x = model(x)
|
| 785 |
-
print('output image size: ', x.shape)
|
| 786 |
-
flops, params = profile(model, (x,))
|
| 787 |
-
print(flops)
|
| 788 |
-
print(params)
|
|
|
|
| 3 |
import torch.utils.checkpoint as checkpoint
|
| 4 |
from einops import rearrange
|
| 5 |
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
|
| 6 |
+
|
| 7 |
|
| 8 |
class Mlp(nn.Module):
|
| 9 |
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
|
|
|
|
| 763 |
flops += self.num_features * self.out_chans
|
| 764 |
return flops
|
| 765 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|