File size: 1,229 Bytes
475e066
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
# This file defines download_models(), which fetches the model weights.
# It runs at container build time so the weights are baked into the container image.

import os
import wget
import json
import tarfile
import tempfile

def download_models(config):
    """Download the model weights described by *config* next to this file.

    Args:
        config: Mapping from model name to its settings. The ``'u2net'``,
            ``'realesrgan'`` and ``'diffuser'`` entries must each provide a
            ``'download_url'``. ``'u2net'`` and ``'realesrgan'`` must also
            provide a ``'path'``, relative to this file's directory, where
            the downloaded file is written.

    The ``'diffuser'`` download is opened as a tar archive and extracted into
    ``checkpoints/`` next to this file. The archive itself is only stored in a
    temporary directory that is removed afterwards, even if extraction fails.
    """
    base_dir = os.path.dirname(__file__)

    # Download the U2-Net checkpoint.
    # NOTE: the SCHP parser checkpoint (config['schp']) is intentionally not
    # downloaded; its download call was disabled in this script.
    wget.download(config['u2net']['download_url'],
                  os.path.join(base_dir, config['u2net']['path']))

    # Download the super-resolution model.
    wget.download(config['realesrgan']['download_url'],
                  os.path.join(base_dir, config['realesrgan']['path']))

    # Download the diffuser checkpoint archive. A private temporary directory
    # avoids the leaked descriptor and stray files that mkstemp() left behind.
    with tempfile.TemporaryDirectory() as tmp_dir:
        archive_path = os.path.join(tmp_dir, 'diffuser.tar')
        wget.download(config['diffuser']['download_url'], archive_path)
        target_dir = os.path.join(base_dir, 'checkpoints/')
        with tarfile.open(archive_path) as tar_file:
            if hasattr(tarfile, 'data_filter'):
                # Reject members that would escape target_dir (absolute
                # paths, '..', outside links) when the runtime supports it.
                tar_file.extractall(target_dir, filter='data')
            else:
                # Older Python without extraction filters: unfiltered extract.
                tar_file.extractall(target_dir)

if __name__ == "__main__":
    # Resolve the config relative to this file rather than the caller's
    # working directory, then download every model in its 'models' section.
    config_file = "configs/configs.json"
    config_file = os.path.join(os.path.dirname(__file__), config_file)

    # JSON text is UTF-8 (RFC 8259); do not depend on the build machine's
    # locale-specific default encoding.
    with open(config_file, encoding="utf-8") as fin:
        config = json.load(fin)
    download_models(config['models'])