File size: 4,203 Bytes
373af33
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
# -*- coding: utf-8 -*-
#
# Copyright (C) 2019 Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG),
# acting on behalf of its Max Planck Institute for Intelligent Systems and the
# Max Planck Institute for Biological Cybernetics. All rights reserved.
#
# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is holder of all proprietary rights
# on this computer program. You can only use this computer program if you have closed a license agreement
# with MPG or you get the right to use the computer program from someone who is authorized to grant you that right.
# Any use of the computer program without a valid license is prohibited and liable to prosecution.
# Contact: [email protected]
#
#
# If you use this code in a research publication please consider citing the following:
#
# Expressive Body Capture: 3D Hands, Face, and Body from a Single Image <https://arxiv.org/abs/1904.05866>
#
#
# Code Developed by: Nima Ghorbani <https://www.linkedin.com/in/nghorbani/>
# 2018.01.02

import os, glob
import numpy as np
from human_body_prior.tools.configurations import load_config, dump_config
import os.path as osp

def exprdir2model(expr_dir):
    '''
    Locate the newest checkpoint of an experiment and load the experiment settings.

    The experiment directory is expected to contain a ``snapshots`` folder holding
    ``*.ckpt`` files, and at least one ``*.yaml`` settings file in the directory that
    contains that ``snapshots`` folder.

    :param expr_dir: path to the experiment directory
    :return: (model_ps, trained_weights_fname) -- the settings object returned by
        ``load_config`` (with ``logging.best_model_fname`` set to the selected checkpoint)
        and the path of the most recently modified checkpoint
    :raises ValueError: if the experiment directory, a checkpoint, or a yaml settings file is missing
    '''
    if not os.path.exists(expr_dir):
        raise ValueError('Could not find the experiment directory: %s' % expr_dir)

    model_snapshots_dir = osp.join(expr_dir, 'snapshots')
    # The most recently modified checkpoint is taken as the trained model.
    available_ckpts = sorted(glob.glob(osp.join(model_snapshots_dir, '*.ckpt')), key=osp.getmtime)
    if not available_ckpts:
        # Explicit raise instead of assert: asserts are stripped when Python runs with -O.
        raise ValueError('No check points found at {}'.format(model_snapshots_dir))
    trained_weights_fname = available_ckpts[-1]

    # The settings file sits next to the snapshots folder. Using os.path.dirname (instead of
    # splitting on '/' and re-joining onto '/') keeps relative paths relative, so they are never
    # searched at the filesystem root, and works with any path separator.
    settings_dir = osp.dirname(osp.dirname(trained_weights_fname))
    # Sorted so the chosen yaml does not depend on glob's unspecified ordering.
    model_ps_fnames = sorted(glob.glob(osp.join(settings_dir, '*.yaml')))
    if not model_ps_fnames:
        raise ValueError('No yaml settings file found at {}'.format(settings_dir))

    model_ps = load_config(default_ps_fname=model_ps_fnames[0])
    model_ps.logging.best_model_fname = trained_weights_fname

    return model_ps, trained_weights_fname


def load_model(expr_dir, model_code=None, remove_words_in_model_weights=None, load_only_ps=False, disable_grad=True, custom_ps = None):
    '''
    Build a model and load the newest checkpoint of an experiment into it.

    :param expr_dir: path to the experiment directory (must contain a ``snapshots`` folder
        with ``*.ckpt`` files and a ``*.yaml`` settings file)
    :param model_code: model class (any callable) instantiated as ``model_code(model_ps)``;
        e.g. ``from supercap.train.supercap_smpl import SuperCap`` and pass ``SuperCap``
    :param remove_words_in_model_weights: optional prefix removed from the start of every
        checkpoint key that begins with it; other keys are left unchanged
    :param load_only_ps: if True, return only the experiment settings, without building a model
    :param disable_grad: if True, set ``requires_grad = False`` on every model parameter
    :param custom_ps: optional settings object used instead of the experiment's own settings
        to construct the model (it is also the settings object returned)
    :return: ``model_ps`` when ``load_only_ps`` is True, otherwise ``(model_instance, model_ps)``
        with the model switched to eval mode
    :raises ValueError: if ``model_code`` is not provided (and ``load_only_ps`` is False)
    '''
    import torch

    model_ps, trained_weights_fname = exprdir2model(expr_dir)
    if load_only_ps:
        return model_ps
    if custom_ps is not None:
        model_ps = custom_ps
    if model_code is None:
        # Explicit raise instead of assert: asserts are stripped when Python runs with -O.
        raise ValueError('model_code should be provided')

    model_instance = model_code(model_ps)
    if disable_grad:
        # torch.no_grad() only affects operations run inside its context; freezing the
        # parameters themselves makes the returned model non-trainable everywhere.
        for param in model_instance.parameters():
            param.requires_grad = False

    # map_location='cpu' lets checkpoints saved on a GPU load on CPU-only machines;
    # load_state_dict copies the values into the model's own tensors, so the model's
    # device is unaffected.
    # NOTE: torch.load unpickles the file -- only load checkpoints from trusted sources.
    state_dict = torch.load(trained_weights_fname, map_location='cpu')['state_dict']

    if remove_words_in_model_weights is not None:
        prefix = str(remove_words_in_model_weights)
        # Strip only the leading prefix; str.replace would also delete later occurrences
        # inside the key (e.g. 'model.sub_model.w' -> 'sub_w').
        state_dict = {(k[len(prefix):] if k.startswith(prefix) else k): v for k, v in state_dict.items()}

    # Keep only the checkpoint entries the current model defines. Model entries missing from
    # the checkpoint keep their initial values because strict=False.
    instance_model_keys = set(model_instance.state_dict().keys())
    state_dict = {k: v for k, v in state_dict.items() if k in instance_model_keys}
    # TODO: fix the issues so that strict can be set to True. The body model uses unnecessary registered buffers.
    model_instance.load_state_dict(state_dict, strict=False)
    model_instance.eval()

    return model_instance, model_ps