vobecant
commited on
Commit
·
179cb5d
1
Parent(s):
bd42ce3
Initial commit.
Browse files- .idea/workspace.xml +10 -4
- segmenter_model/factory.py +10 -16
.idea/workspace.xml
CHANGED
|
@@ -2,8 +2,7 @@
|
|
| 2 |
<project version="4">
|
| 3 |
<component name="ChangeListManager">
|
| 4 |
<list default="true" id="5dd22f22-8223-4d55-99f9-57d1e00622d7" name="Default Changelist" comment="Initial commit.">
|
| 5 |
-
<change beforePath="$PROJECT_DIR
|
| 6 |
-
<change beforePath="$PROJECT_DIR$/app.py" beforeDir="false" afterPath="$PROJECT_DIR$/app.py" afterDir="false" />
|
| 7 |
</list>
|
| 8 |
<option name="SHOW_DIALOG" value="false" />
|
| 9 |
<option name="HIGHLIGHT_CONFLICTS" value="true" />
|
|
@@ -51,7 +50,7 @@
|
|
| 51 |
<option name="number" value="Default" />
|
| 52 |
<option name="presentableId" value="Default" />
|
| 53 |
<updated>1647350746642</updated>
|
| 54 |
-
<workItem from="1647350750956" duration="
|
| 55 |
</task>
|
| 56 |
<task id="LOCAL-00001" summary="Initial commit.">
|
| 57 |
<created>1647352693910</created>
|
|
@@ -137,7 +136,14 @@
|
|
| 137 |
<option name="project" value="LOCAL" />
|
| 138 |
<updated>1647356274640</updated>
|
| 139 |
</task>
|
| 140 |
-
<
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 141 |
<servers />
|
| 142 |
</component>
|
| 143 |
<component name="TypeScriptGeneratedFilesManager">
|
|
|
|
| 2 |
<project version="4">
|
| 3 |
<component name="ChangeListManager">
|
| 4 |
<list default="true" id="5dd22f22-8223-4d55-99f9-57d1e00622d7" name="Default Changelist" comment="Initial commit.">
|
| 5 |
+
<change beforePath="$PROJECT_DIR$/segmenter_model/factory.py" beforeDir="false" afterPath="$PROJECT_DIR$/segmenter_model/factory.py" afterDir="false" />
|
|
|
|
| 6 |
</list>
|
| 7 |
<option name="SHOW_DIALOG" value="false" />
|
| 8 |
<option name="HIGHLIGHT_CONFLICTS" value="true" />
|
|
|
|
| 50 |
<option name="number" value="Default" />
|
| 51 |
<option name="presentableId" value="Default" />
|
| 52 |
<updated>1647350746642</updated>
|
| 53 |
+
<workItem from="1647350750956" duration="5731000" />
|
| 54 |
</task>
|
| 55 |
<task id="LOCAL-00001" summary="Initial commit.">
|
| 56 |
<created>1647352693910</created>
|
|
|
|
| 136 |
<option name="project" value="LOCAL" />
|
| 137 |
<updated>1647356274640</updated>
|
| 138 |
</task>
|
| 139 |
+
<task id="LOCAL-00013" summary="Initial commit.">
|
| 140 |
+
<created>1647356326582</created>
|
| 141 |
+
<option name="number" value="00013" />
|
| 142 |
+
<option name="presentableId" value="LOCAL-00013" />
|
| 143 |
+
<option name="project" value="LOCAL" />
|
| 144 |
+
<updated>1647356326582</updated>
|
| 145 |
+
</task>
|
| 146 |
+
<option name="localTasksCounter" value="14" />
|
| 147 |
<servers />
|
| 148 |
</component>
|
| 149 |
<component name="TypeScriptGeneratedFilesManager">
|
segmenter_model/factory.py
CHANGED
|
@@ -1,18 +1,17 @@
|
|
| 1 |
-
from pathlib import Path
|
| 2 |
-
import yaml
|
| 3 |
-
import torch
|
| 4 |
-
import math
|
| 5 |
import os
|
| 6 |
-
|
| 7 |
|
|
|
|
|
|
|
| 8 |
from timm.models.helpers import load_pretrained, load_custom_pretrained
|
| 9 |
-
from timm.models.vision_transformer import default_cfgs, checkpoint_filter_fn
|
| 10 |
from timm.models.registry import register_model
|
| 11 |
from timm.models.vision_transformer import _create_vision_transformer
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
from segmenter_model.decoder import MaskTransformer
|
| 13 |
from segmenter_model.segmenter import Segmenter
|
| 14 |
-
import segmenter_model.torch as ptu
|
| 15 |
-
|
| 16 |
from segmenter_model.vit_dino import vit_small, VisionTransformer
|
| 17 |
|
| 18 |
|
|
@@ -48,14 +47,9 @@ def create_vit(model_cfg):
|
|
| 48 |
model_cfg['drop_rate'] = model_cfg['dropout']
|
| 49 |
model = vit_small(**model_cfg)
|
| 50 |
# hard-coded for now, too lazy
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
pretrained_weights = ciirc_path
|
| 55 |
-
elif os.path.exists(karolina_path):
|
| 56 |
-
pretrained_weights = karolina_path
|
| 57 |
-
else:
|
| 58 |
-
raise Exception('DINO weights not found!')
|
| 59 |
model.load_state_dict(torch.load(pretrained_weights), strict=True)
|
| 60 |
else:
|
| 61 |
model = torch.hub.load('facebookresearch/dino:main', backbone)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import os
|
| 2 |
+
from pathlib import Path
|
| 3 |
|
| 4 |
+
import requests
|
| 5 |
+
import yaml
|
| 6 |
from timm.models.helpers import load_pretrained, load_custom_pretrained
|
|
|
|
| 7 |
from timm.models.registry import register_model
|
| 8 |
from timm.models.vision_transformer import _create_vision_transformer
|
| 9 |
+
from timm.models.vision_transformer import default_cfgs, checkpoint_filter_fn
|
| 10 |
+
|
| 11 |
+
import segmenter_model.torch as ptu
|
| 12 |
+
import torch
|
| 13 |
from segmenter_model.decoder import MaskTransformer
|
| 14 |
from segmenter_model.segmenter import Segmenter
|
|
|
|
|
|
|
| 15 |
from segmenter_model.vit_dino import vit_small, VisionTransformer
|
| 16 |
|
| 17 |
|
|
|
|
| 47 |
model_cfg['drop_rate'] = model_cfg['dropout']
|
| 48 |
model = vit_small(**model_cfg)
|
| 49 |
# hard-coded for now, too lazy
|
| 50 |
+
pretrained_weights = 'dino_deitsmall16_pretrain.pth'
|
| 51 |
+
if not os.path.exists(pretrained_weights):
|
| 52 |
+
requests.get(pretrained_weights, allow_redirects=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 53 |
model.load_state_dict(torch.load(pretrained_weights), strict=True)
|
| 54 |
else:
|
| 55 |
model = torch.hub.load('facebookresearch/dino:main', backbone)
|