CCockrum commited on
Commit
2ae1a70
·
verified ·
1 Parent(s): fe9d6c5

Update utils.py

Browse files
Files changed (1) hide show
  1. utils.py +140 -0
utils.py CHANGED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, zipfile, shutil, subprocess, shlex, sys # noqa
2
+ from urllib.parse import urlparse
3
+ import re
4
+ import logging
5
+
6
+
7
+ def load_file_from_url(
8
+ url: str,
9
+ model_dir: str,
10
+ file_name: str | None = None,
11
+ overwrite: bool = False,
12
+ progress: bool = True,
13
+ ) -> str:
14
+ """Download a file from `url` into `model_dir`,
15
+ using the file present if possible.
16
+ Returns the path to the downloaded file.
17
+ """
18
+ os.makedirs(model_dir, exist_ok=True)
19
+ if not file_name:
20
+ parts = urlparse(url)
21
+ file_name = os.path.basename(parts.path)
22
+ cached_file = os.path.abspath(os.path.join(model_dir, file_name))
23
+
24
+ # Overwrite
25
+ if os.path.exists(cached_file):
26
+ if overwrite or os.path.getsize(cached_file) == 0:
27
+ remove_files(cached_file)
28
+
29
+ # Download
30
+ if not os.path.exists(cached_file):
31
+ logger.info(f'Downloading: "{url}" to {cached_file}\n')
32
+ from torch.hub import download_url_to_file
33
+
34
+ download_url_to_file(url, cached_file, progress=progress)
35
+ else:
36
+ logger.debug(cached_file)
37
+
38
+ return cached_file
39
+
40
+
41
+ def friendly_name(file: str):
42
+ if file.startswith("http"):
43
+ file = urlparse(file).path
44
+
45
+ file = os.path.basename(file)
46
+ model_name, extension = os.path.splitext(file)
47
+ return model_name, extension
48
+
49
+
50
+ def download_manager(
51
+ url: str,
52
+ path: str,
53
+ extension: str = "",
54
+ overwrite: bool = False,
55
+ progress: bool = True,
56
+ ):
57
+ url = url.strip()
58
+
59
+ name, ext = friendly_name(url)
60
+ name += ext if not extension else f".{extension}"
61
+
62
+ if url.startswith("http"):
63
+ filename = load_file_from_url(
64
+ url=url,
65
+ model_dir=path,
66
+ file_name=name,
67
+ overwrite=overwrite,
68
+ progress=progress,
69
+ )
70
+ else:
71
+ filename = path
72
+
73
+ return filename
74
+
75
+
76
+ def remove_files(file_list):
77
+ if isinstance(file_list, str):
78
+ file_list = [file_list]
79
+
80
+ for file in file_list:
81
+ if os.path.exists(file):
82
+ os.remove(file)
83
+
84
+
85
+ def remove_directory_contents(directory_path):
86
+ """
87
+ Removes all files and subdirectories within a directory.
88
+ Parameters:
89
+ directory_path (str): Path to the directory whose
90
+ contents need to be removed.
91
+ """
92
+ if os.path.exists(directory_path):
93
+ for filename in os.listdir(directory_path):
94
+ file_path = os.path.join(directory_path, filename)
95
+ try:
96
+ if os.path.isfile(file_path):
97
+ os.remove(file_path)
98
+ elif os.path.isdir(file_path):
99
+ shutil.rmtree(file_path)
100
+ except Exception as e:
101
+ logger.error(f"Failed to delete {file_path}. Reason: {e}")
102
+ logger.info(f"Content in '{directory_path}' removed.")
103
+ else:
104
+ logger.error(f"Directory '{directory_path}' does not exist.")
105
+
106
+
107
+ # Create directory if not exists
108
+ def create_directories(directory_path):
109
+ if isinstance(directory_path, str):
110
+ directory_path = [directory_path]
111
+ for one_dir_path in directory_path:
112
+ if not os.path.exists(one_dir_path):
113
+ os.makedirs(one_dir_path)
114
+ logger.debug(f"Directory '{one_dir_path}' created.")
115
+
116
+
117
+ def setup_logger(name_log):
118
+ logger = logging.getLogger(name_log)
119
+ logger.setLevel(logging.INFO)
120
+
121
+ _default_handler = logging.StreamHandler() # Set sys.stderr as stream.
122
+ _default_handler.flush = sys.stderr.flush
123
+ logger.addHandler(_default_handler)
124
+
125
+ logger.propagate = False
126
+
127
+ handlers = logger.handlers
128
+
129
+ for handler in handlers:
130
+ formatter = logging.Formatter("[%(levelname)s] >> %(message)s")
131
+ handler.setFormatter(formatter)
132
+
133
+ # logger.handlers
134
+
135
+ return logger
136
+
137
+
138
+ logger = setup_logger("ss")
139
+ logger.setLevel(logging.INFO)
140
+