File size: 2,765 Bytes
7b8d670
 
 
 
dc4b749
 
c7383ff
7b8d670
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dc4b749
7b8d670
 
c7383ff
7b8d670
 
c7383ff
 
 
7b8d670
dc4b749
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import os
import requests
from datetime import datetime
from email.utils import parsedate_to_datetime, formatdate
from DeepDeformationMapRegistration.utils.constants import ANATOMIES, MODEL_TYPES, ENCODER_FILTERS, DECODER_FILTERS, IMG_SHAPE
import voxelmorph as vxm
from DeepDeformationMapRegistration.utils.logger import LOGGER


# taken from: https://lenon.dev/blog/downloading-and-caching-large-files-using-python/
def download(url, destination_file, timeout=60):
    """Download ``url`` to ``destination_file``, skipping the transfer if the local copy is current.

    If ``destination_file`` already exists, its modification time is sent as an
    ``If-Modified-Since`` header. A ``304 Not Modified`` reply leaves the file untouched.
    On ``200 OK`` the body is streamed into a temporary ``<destination_file>.part``
    file and moved into place only once it is complete. This way a failed or
    interrupted transfer never leaves a truncated file at ``destination_file``.
    When the server sends a parseable ``Last-Modified`` header, that time is
    copied onto the file's mtime. Later conditional requests then compare against
    the server's timestamp.

    Args:
        url: Address of the file to fetch.
        destination_file: Local path (str or path-like) to write to. Its parent
            directory must already exist.
        timeout: Seconds passed to ``requests.get`` as its ``timeout``. This covers
            connecting and waiting for data, not the total transfer time.
            ``None`` waits forever, which was the previous behaviour.

    Raises:
        requests.HTTPError: If the server answers with a 4xx/5xx status.
        requests.RequestException: On connection errors or timeouts.
        OSError: If the file cannot be written or renamed.
    """
    headers = {}

    if os.path.exists(destination_file):
        # Ask the server to skip the body if our copy is not older than its own.
        mtime = os.path.getmtime(destination_file)
        headers["if-modified-since"] = formatdate(mtime, usegmt=True)

    # Context manager so the streamed connection is always released.
    with requests.get(url, headers=headers, stream=True, timeout=timeout) as response:
        response.raise_for_status()

        if response.status_code == requests.codes.not_modified:
            return

        if response.status_code != requests.codes.ok:
            # Any other non-error status: leave the destination untouched, as before.
            return

        # Write to a sibling temp file and rename it into place only when complete.
        # Otherwise a partial file could later be mistaken for a finished download
        # (get_models_path only checks whether the file exists).
        partial_file = os.fspath(destination_file) + ".part"
        try:
            with open(partial_file, "wb") as f:
                for chunk in response.iter_content(chunk_size=1048576):  # 1 MiB chunks
                    f.write(chunk)
            os.replace(partial_file, destination_file)
        finally:
            # Still present only if writing or renaming failed.
            if os.path.exists(partial_file):
                os.remove(partial_file)

        last_modified = response.headers.get("last-modified")
        if not last_modified:
            return
        try:
            new_mtime = parsedate_to_datetime(last_modified).timestamp()
        except (TypeError, ValueError):
            # Malformed header: the download itself succeeded, so just skip the mtime sync.
            return
        # Access time = now, modification time = server's Last-Modified.
        os.utime(destination_file, times=(datetime.now().timestamp(), new_mtime))


def get_models_path(anatomy: str, model_type: str, output_root_dir: str):
    """Return the local path of a pretrained model, downloading it first if absent.

    ``anatomy`` and ``model_type`` must be keys of ``ANATOMIES`` and ``MODEL_TYPES``.
    Their mapped values name both pieces below:

    - the release asset ``<anatomy>_<model_type>.h5`` under the
      ``trained_models_v0`` GitHub release;
    - the local cache path ``<output_root_dir>/models/<anatomy>/<model_type>.h5``.

    If a file already exists at the cache path, it is returned as-is and the
    server is not contacted.

    Raises:
        AssertionError: If ``anatomy`` or ``model_type`` is not a known key.
    """
    assert anatomy in ANATOMIES.keys(), 'Invalid anatomy'
    assert model_type in MODEL_TYPES.keys(), 'Invalid model type'
    anatomy_name = ANATOMIES[anatomy]
    model_name = MODEL_TYPES[model_type]

    release_base = 'https://github.com/jpdefrutos/DDMR/releases/download/trained_models_v0/'
    url = f'{release_base}{anatomy_name}_{model_name}.h5'
    file_path = os.path.join(output_root_dir, 'models', anatomy_name, f'{model_name}.h5')

    # Cache hit: reuse the local copy.
    if os.path.exists(file_path):
        LOGGER.info(f'Found model: {file_path}')
        return file_path

    LOGGER.info(f'Model not found. Downloading from {url}... ')
    os.makedirs(os.path.dirname(file_path), exist_ok=True)
    download(url, file_path)
    LOGGER.info(f'... downloaded model. Stored in {file_path}')
    return file_path


def load_model(weights_file_path: str, trainable: bool = False, return_registration_model: bool = True):
    """Build a VoxelMorph ``VxmDense`` network and load pretrained weights into it.

    Args:
        weights_file_path: Path (str or path-like) to an existing Keras HDF5
            weights file ending in ``.h5``.
        trainable: Value assigned to the network's ``trainable`` attribute after
            the weights are loaded.
        return_registration_model: If True, also build the registration model
            with ``get_registration_model()``.

    Returns:
        The ``VxmDense`` network. When ``return_registration_model`` is True,
        returns the tuple ``(network, registration_model)`` instead.

    Raises:
        AssertionError: If the file does not exist or does not end in ``.h5``.
            These checks are ``assert`` statements, so they are skipped under
            ``python -O``.
    """
    # Normalise path-like inputs so the suffix check below works on a str.
    weights_file_path = os.fspath(weights_file_path)
    assert os.path.exists(weights_file_path), f'File {weights_file_path} not found'
    # Check the full suffix, including the dot, so that names like 'weightsh5'
    # are rejected, consistent with the error message.
    assert weights_file_path.endswith('.h5'), 'Invalid file extension. Expected .h5'

    # Architecture comes from project constants. IMG_SHAPE[:-1] drops the last
    # axis (presumably the channel axis; confirm against utils.constants).
    # int_steps=0 disables flow integration in VxmDense (non-diffeomorphic warp).
    ret_val = vxm.networks.VxmDense(inshape=IMG_SHAPE[:-1],
                                    nb_unet_features=[ENCODER_FILTERS, DECODER_FILTERS],
                                    int_steps=0)
    # by_name=True matches saved weights to layers by layer name.
    ret_val.load_weights(weights_file_path, by_name=True)
    ret_val.trainable = trainable

    if return_registration_model:
        ret_val = (ret_val, ret_val.get_registration_model())

    return ret_val