jpdefrutos's picture
Refactor the model utils functions (download weight files, get weight files path, and load models)
dc4b749
raw
history blame
2.51 kB
import os
import requests
from datetime import datetime
from email.utils import parsedate_to_datetime, formatdate
from DeepDeformationMapRegistration.utils.constants import ANATOMIES, MODEL_TYPES, ENCODER_FILTERS, DECODER_FILTERS, IMG_SHAPE
import voxelmorph as vxm
# taken from: https://lenon.dev/blog/downloading-and-caching-large-files-using-python/
def download(url, destination_file):
headers = {}
if os.path.exists(destination_file):
mtime = os.path.getmtime(destination_file)
headers["if-modified-since"] = formatdate(mtime, usegmt=True)
response = requests.get(url, headers=headers, stream=True)
response.raise_for_status()
if response.status_code == requests.codes.not_modified:
return
if response.status_code == requests.codes.ok:
with open(destination_file, "wb") as f:
for chunk in response.iter_content(chunk_size=1048576):
f.write(chunk)
last_modified = response.headers.get("last-modified")
if last_modified:
new_mtime = parsedate_to_datetime(last_modified).timestamp()
os.utime(destination_file, times=(datetime.now().timestamp(), new_mtime))
def get_models_path(anatomy: str, model_type: str, output_root_dir: str):
assert anatomy in ANATOMIES.keys(), 'Invalid anatomy'
assert model_type in MODEL_TYPES.keys(), 'Invalid model type'
anatomy = ANATOMIES[anatomy]
model_type = MODEL_TYPES[model_type]
url = 'https://github.com/jpdefrutos/DDMR/releases/download/trained_models_v0/' + anatomy + '_' + model_type + '.h5'
file_path = os.path.join(output_root_dir, 'models', anatomy, model_type + '.h5')
if not os.path.exists(file_path):
os.makedirs(os.path.split(file_path)[0], exist_ok=True)
download(url, file_path)
return file_path
def load_model(weights_file_path: str, trainable: bool = False, return_registration_model: bool=True):
assert os.path.exists(weights_file_path), f'File {weights_file_path} not found'
assert weights_file_path.endswith('h5'), 'Invalid file extension. Expected .h5'
ret_val = vxm.networks.VxmDense(inshape=IMG_SHAPE[:-1],
nb_unet_features=[ENCODER_FILTERS, DECODER_FILTERS],
int_steps=0)
ret_val.load_weights(weights_file_path, by_name=True)
ret_val.trainable = trainable
if return_registration_model:
ret_val = (ret_val, ret_val.get_registration_model())
return ret_val