|
import hashlib
import os
import shutil
import ssl
import tempfile
from urllib.request import Request, urlopen

import torch
from torch.hub import get_dir
from tqdm import tqdm

from segmentation_models_pytorch.encoders import (
    resnet_encoders, dpn_encoders, vgg_encoders, senet_encoders,
    densenet_encoders, inceptionresnetv2_encoders, inceptionv4_encoders,
    efficient_net_encoders, mobilenet_encoders, xception_encoders,
    timm_efficientnet_encoders, timm_resnest_encoders, timm_res2net_encoders,
    timm_regnet_encoders, timm_sknet_encoders, timm_mobilenetv3_encoders,
    timm_gernet_encoders,
)
from segmentation_models_pytorch.encoders.timm_universal import TimmUniversalEncoder
|
|
|
def initialize_encoders():
    """Return a new dict mapping encoder names to their smp configuration entries.

    Every built-in encoder table imported at module level is merged; when a
    name appears in more than one table, the later table wins.

    If ``segmentation_models_pytorch.__version__`` is at least 0.3.3, the Mix
    Transformer and MobileOne tables are merged in as well. Any ImportError on
    that optional path (including a missing ``packaging`` package) makes the
    function return only the built-in tables.

    Note: the outer dict is freshly built, but its values are the very same
    config objects held by the smp encoder modules (no copies are made).
    """
    core_tables = (
        resnet_encoders, dpn_encoders, vgg_encoders, senet_encoders,
        densenet_encoders, inceptionresnetv2_encoders, inceptionv4_encoders,
        efficient_net_encoders, mobilenet_encoders, xception_encoders,
        timm_efficientnet_encoders, timm_resnest_encoders, timm_res2net_encoders,
        timm_regnet_encoders, timm_sknet_encoders, timm_mobilenetv3_encoders,
        timm_gernet_encoders,
    )
    registry = {
        encoder_name: config
        for table in core_tables
        for encoder_name, config in table.items()
    }

    try:
        import segmentation_models_pytorch
        from packaging import version

        installed = version.parse(segmentation_models_pytorch.__version__)
        if installed < version.parse("0.3.3"):
            # Older smp releases do not ship these encoder families.
            return registry

        from segmentation_models_pytorch.encoders.mix_transformer import mix_transformer_encoders
        from segmentation_models_pytorch.encoders.mobileone import mobileone_encoders
    except ImportError:
        # Optional extras: silently fall back to the built-in tables.
        return registry

    registry.update(mix_transformer_encoders)
    registry.update(mobileone_encoders)
    return registry
|
|
|
def download_weights(url, destination, hash_prefix=None, show_progress=True, ssl_context=None):
    """Download *url* to *destination* with an optional progress bar and hash check.

    The payload is streamed into a temporary file created in the directory of
    *destination*. It is moved into place only after the whole body has been
    read and, if requested, verified, so a failed download never leaves a
    partial file at *destination*.

    Args:
        url: URL to fetch.
        destination: Target file path; ``~`` is expanded. Its parent
            directory must already exist (the temporary file is created there).
        hash_prefix: Optional expected prefix of the SHA-256 hex digest of the
            downloaded bytes. When falsy, no hash is computed.
        show_progress: Whether to display a tqdm progress bar.
        ssl_context: ``ssl.SSLContext`` used for HTTPS connections made by this
            request (including redirects). Defaults to an *unverified* context
            to preserve this function's historical behaviour.

    Raises:
        RuntimeError: If the SHA-256 digest does not start with *hash_prefix*.
        urllib.error.URLError: If the request fails.
    """
    if ssl_context is None:
        # NOTE(security): certificate verification stays disabled for backward
        # compatibility. The context is now local to this request instead of
        # being monkeypatched into the ssl module for the whole process.
        ssl_context = ssl._create_unverified_context()

    destination = os.path.expanduser(destination)
    request = Request(url, headers={"User-Agent": "torch.hub"})

    # `with` guarantees the HTTP connection is closed on every path.
    with urlopen(request, context=ssl_context) as response:
        # The header is a single string such as "12345"; parse all of it
        # (indexing [0] would only take the first digit).
        content_length = response.headers.get("Content-Length")
        try:
            file_size = int(content_length) if content_length else None
        except ValueError:
            file_size = None  # malformed header: fall back to an unknown total

        temp_file = tempfile.NamedTemporaryFile(delete=False, dir=os.path.dirname(destination))

        try:
            hasher = hashlib.sha256() if hash_prefix else None

            with tqdm(total=file_size, disable=not show_progress,
                      unit="B", unit_scale=True, unit_divisor=1024) as pbar:
                while True:
                    buffer = response.read(8192)
                    if not buffer:
                        break
                    temp_file.write(buffer)
                    if hasher is not None:
                        hasher.update(buffer)
                    pbar.update(len(buffer))

            # Flush and close before hashing/moving so the file is complete on disk.
            temp_file.close()

            if hasher is not None:
                digest = hasher.hexdigest()
                if not digest.startswith(hash_prefix):
                    raise RuntimeError(f'Invalid hash value (expected "{hash_prefix}", got "{digest}")')

            shutil.move(temp_file.name, destination)
        finally:
            # After a successful move the temp path no longer exists; on any
            # error it is removed so no stray partial downloads accumulate.
            temp_file.close()
            if os.path.exists(temp_file.name):
                os.remove(temp_file.name)
|
|
|
def initialize_encoder(name, in_channels=3, depth=5, weights=None, output_stride=32, **kwargs):
    """Build and return a configured encoder.

    Names starting with ``"tu-"`` are delegated to ``TimmUniversalEncoder``
    with the prefix stripped. In that case *weights* only acts as a
    pretrained on/off switch (``weights is not None``) and ``**kwargs`` are
    forwarded.

    Every other name is looked up in ``initialize_encoders()``. The encoder is
    built with the registry's params plus *depth*. If *weights* is given, the
    matching checkpoint is downloaded into ``<torch hub dir>/checkpoints``
    (unless already cached) and loaded. The input layer is then adapted to
    *in_channels*, and the encoder is dilated when *output_stride* != 32.
    ``**kwargs`` are not used for these encoders.

    Args:
        name: Encoder name, optionally prefixed with ``"tu-"``.
        in_channels: Number of input channels.
        depth: Number of encoder stages.
        weights: Key into the encoder's ``pretrained_settings`` (or None).
        output_stride: Desired output stride; 32 leaves the encoder undilated.

    Returns:
        The configured encoder module.

    Raises:
        KeyError: If *name* or *weights* is not available.
    """
    if name.startswith("tu-"):
        return TimmUniversalEncoder(
            name=name[3:],
            in_channels=in_channels,
            depth=depth,
            output_stride=output_stride,
            pretrained=weights is not None,
            **kwargs
        )

    encoders = initialize_encoders()
    try:
        encoder_config = encoders[name]
    except KeyError:
        raise KeyError(f"Invalid encoder name '{name}'. Available encoders: {list(encoders.keys())}") from None

    weights_path = None
    if weights:
        try:
            weights_config = encoder_config["pretrained_settings"][weights]
        except KeyError:
            raise KeyError(
                f"Invalid weights '{weights}' for encoder '{name}'. "
                f"Available options: {list(encoder_config['pretrained_settings'].keys())}"
            ) from None

        cache_dir = os.path.join(get_dir(), 'checkpoints')
        os.makedirs(cache_dir, exist_ok=True)

        weights_url = weights_config["url"]
        weights_file = os.path.basename(weights_url)
        weights_path = os.path.join(cache_dir, weights_file)

        if not os.path.exists(weights_path):
            print(f'Downloading {weights_file}...')
            # Keep the original scheme: downgrading to plain HTTP would let a
            # network attacker swap the checkpoint, which torch.load unpickles.
            download_weights(weights_url, weights_path)

    # Copy the params rather than updating the registry dict in place, which
    # would mutate the shared module-level smp configuration.
    encoder_params = {**encoder_config["params"], "depth": depth}
    encoder = encoder_config["encoder"](**encoder_params)

    if weights_path is not None:
        encoder.load_state_dict(torch.load(weights_path, map_location="cpu"))

    # Adapt the first layer to in_channels and apply dilation.
    encoder.set_in_channels(in_channels, pretrained=weights_path is not None)
    if output_stride != 32:
        encoder.make_dilated(output_stride)

    return encoder