Spaces:
Running
Running
| from __future__ import print_function | |
| import os | |
| from subprocess import call | |
| from builtins import input | |
# Resolve where the pretrained weights live. The path is relative to the
# current working directory; when the script is run from inside a folder
# named `scripts`, the `model` folder is assumed to be one level up.
curr_folder = os.path.basename(os.path.normpath(os.getcwd()))
weights_filename = 'pytorch_model.bin'
weights_folder = 'model'
weights_path = os.path.join(weights_folder, weights_filename)
if curr_folder == 'scripts':
    weights_path = os.path.join(os.pardir, weights_path)
# `dl=1` makes Dropbox serve the raw file; `dl=0` returns an HTML preview
# page instead, which would be far smaller than the expected ~85MB.
weights_download_link = 'https://www.dropbox.com/s/q8lax9ary32c7t9/pytorch_model.bin?dl=1'
# Bytes per mebibyte. It is a float so that dividing an integer byte count by
# it gives fractional megabytes even on Python 2 (this file does not import
# `division` from __future__).
MB_FACTOR = float(1 << 20)
def prompt(read=input):
    """Read yes/no answers until a recognised one is given.

    Accepted answers are 'y', 'ye', 'yes', 'n' and 'no'. Matching is
    case-insensitive and ignores surrounding whitespace. Any other answer
    prints a hint and the question is read again.

    Args:
        read: zero-argument callable that returns one line of user input.
            Defaults to ``input``; pass another callable to supply answers
            without a real stdin.

    Returns:
        True for a "yes" answer, False for a "no" answer.
    """
    # Built once, not on every loop iteration.
    valid = {
        'y': True,
        'ye': True,
        'yes': True,
        'n': False,
        'no': False,
    }
    while True:
        # Strip so stray spaces (e.g. "yes ") are not treated as invalid.
        choice = read().strip().lower()
        if choice in valid:
            return valid[choice]
        print('Please respond with \'y\' or \'n\' (or \'yes\' or \'no\')')
# Decide whether a (re)download is needed; ask before replacing an existing file.
download = True
if os.path.exists(weights_path):
    print('Weight file already exists at {}. Would you like to redownload it anyway? [y/n]'.format(weights_path))
    download = prompt()
    already_exists = True
else:
    already_exists = False

if download:
    print('About to download the pretrained weights file from {}'.format(weights_download_link))
    if not already_exists:
        print('The size of the file is roughly 85MB. Continue? [y/n]')
    # If the file already existed, the user confirmed above, so skip the
    # second question.
    if already_exists or prompt():
        print('Downloading...')
        # Download into a side file so an existing copy is replaced only
        # after a complete download has passed the size check.
        partial_path = weights_path + '.part'
        target_dir = os.path.dirname(os.path.abspath(weights_path))
        if not os.path.isdir(target_dir):
            # `wget -O` does not create missing parent directories.
            os.makedirs(target_dir)
        # wget is used instead of urllib/requests because the original author
        # reported problems with urlretrieve and requests. An argument list
        # (no shell) keeps paths with spaces or shell metacharacters intact.
        sys_call = ['wget', weights_download_link, '-O', os.path.abspath(partial_path)]
        print("Running system call: {}".format(' '.join(sys_call)))
        return_code = call(sys_call)
        if return_code != 0:
            if os.path.exists(partial_path):
                os.remove(partial_path)
            raise RuntimeError("wget exited with status {}; download failed.".format(return_code))
        size_bytes = os.path.getsize(partial_path)
        if size_bytes / MB_FACTOR < 80:
            # Discard the bad file so a later run does not treat it as valid.
            os.remove(partial_path)
            raise ValueError("Download finished, but the resulting file is too small! " +
                             "It\'s only {} bytes.".format(size_bytes))
        if os.path.exists(weights_path):
            # os.rename cannot overwrite an existing file on Windows.
            os.remove(weights_path)
        os.rename(partial_path, weights_path)
        print('Downloaded weights to {}'.format(weights_path))
    else:
        print('Exiting.')