r3gm commited on
Commit
eb210c4
·
verified ·
1 Parent(s): d3450c5
Files changed (2) hide show
  1. app.py +0 -0
  2. utils.py +714 -562
app.py CHANGED
The diff for this file is too large to render. See raw diff
 
utils.py CHANGED
@@ -1,562 +1,714 @@
1
- import os
2
- import re
3
- import gradio as gr
4
- from constants import (
5
- DIFFUSERS_FORMAT_LORAS,
6
- CIVITAI_API_KEY,
7
- HF_TOKEN,
8
- MODEL_TYPE_CLASS,
9
- DIRECTORY_LORAS,
10
- DIRECTORY_MODELS,
11
- DIFFUSECRAFT_CHECKPOINT_NAME,
12
- CACHE_HF_ROOT,
13
- CACHE_HF,
14
- STORAGE_ROOT,
15
- )
16
- from huggingface_hub import HfApi
17
- from huggingface_hub import snapshot_download
18
- from diffusers import DiffusionPipeline
19
- from huggingface_hub import model_info as model_info_data
20
- from diffusers.pipelines.pipeline_loading_utils import variant_compatible_siblings
21
- from stablepy.diffusers_vanilla.utils import checkpoint_model_type
22
- from pathlib import PosixPath
23
- from unidecode import unidecode
24
- import urllib.parse
25
- import copy
26
- import requests
27
- from requests.adapters import HTTPAdapter
28
- from urllib3.util import Retry
29
- import shutil
30
- import subprocess
31
-
32
- IS_ZERO_GPU = bool(os.getenv("SPACES_ZERO_GPU"))
33
- USER_AGENT = 'Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:127.0) Gecko/20100101 Firefox/127.0'
34
-
35
-
36
- def request_json_data(url):
37
- model_version_id = url.split('/')[-1]
38
- if "?modelVersionId=" in model_version_id:
39
- match = re.search(r'modelVersionId=(\d+)', url)
40
- model_version_id = match.group(1)
41
-
42
- endpoint_url = f"https://civitai.com/api/v1/model-versions/{model_version_id}"
43
-
44
- params = {}
45
- headers = {'User-Agent': USER_AGENT, 'content-type': 'application/json'}
46
- session = requests.Session()
47
- retries = Retry(total=5, backoff_factor=1, status_forcelist=[500, 502, 503, 504])
48
- session.mount("https://", HTTPAdapter(max_retries=retries))
49
-
50
- try:
51
- result = session.get(endpoint_url, params=params, headers=headers, stream=True, timeout=(3.0, 15))
52
- result.raise_for_status()
53
- json_data = result.json()
54
- return json_data if json_data else None
55
- except Exception as e:
56
- print(f"Error: {e}")
57
- return None
58
-
59
-
60
- class ModelInformation:
61
- def __init__(self, json_data):
62
- self.model_version_id = json_data.get("id", "")
63
- self.model_id = json_data.get("modelId", "")
64
- self.download_url = json_data.get("downloadUrl", "")
65
- self.model_url = f"https://civitai.com/models/{self.model_id}?modelVersionId={self.model_version_id}"
66
- self.filename_url = next(
67
- (v.get("name", "") for v in json_data.get("files", []) if str(self.model_version_id) in v.get("downloadUrl", "") and v.get("type", "Model") == "Model"), ""
68
- )
69
- self.filename_url = self.filename_url if self.filename_url else ""
70
- self.description = json_data.get("description", "")
71
- if self.description is None:
72
- self.description = ""
73
- self.model_name = json_data.get("model", {}).get("name", "")
74
- self.model_type = json_data.get("model", {}).get("type", "")
75
- self.nsfw = json_data.get("model", {}).get("nsfw", False)
76
- self.poi = json_data.get("model", {}).get("poi", False)
77
- self.images = [img.get("url", "") for img in json_data.get("images", [])]
78
- self.example_prompt = json_data.get("trainedWords", [""])[0] if json_data.get("trainedWords") else ""
79
- self.original_json = copy.deepcopy(json_data)
80
-
81
-
82
- def get_civit_params(url):
83
- try:
84
- json_data = request_json_data(url)
85
- mdc = ModelInformation(json_data)
86
- if mdc.download_url and mdc.filename_url:
87
- return mdc.download_url, mdc.filename_url, mdc.model_url
88
- else:
89
- ValueError("Invalid Civitai model URL")
90
- except Exception as e:
91
- print(f"Error retrieving Civitai metadata: {e} — fallback to direct download")
92
- return url, None, None
93
-
94
-
95
- def civ_redirect_down(url, dir_, civitai_api_key, romanize, alternative_name):
96
- filename_base = filename = None
97
-
98
- if alternative_name:
99
- output_path = os.path.join(dir_, alternative_name)
100
- if os.path.exists(output_path):
101
- return output_path, alternative_name
102
-
103
- # Follow the redirect to get the actual download URL
104
- curl_command = (
105
- f'curl -L -sI --connect-timeout 5 --max-time 5 '
106
- f'-H "Content-Type: application/json" '
107
- f'-H "Authorization: Bearer {civitai_api_key}" "{url}"'
108
- )
109
-
110
- headers = os.popen(curl_command).read()
111
-
112
- # Look for the redirected "Location" URL
113
- location_match = re.search(r'location: (.+)', headers, re.IGNORECASE)
114
-
115
- if location_match:
116
- redirect_url = location_match.group(1).strip()
117
-
118
- # Extract the filename from the redirect URL's "Content-Disposition"
119
- filename_match = re.search(r'filename%3D%22(.+?)%22', redirect_url)
120
- if filename_match:
121
- encoded_filename = filename_match.group(1)
122
- # Decode the URL-encoded filename
123
- decoded_filename = urllib.parse.unquote(encoded_filename)
124
-
125
- filename = unidecode(decoded_filename) if romanize else decoded_filename
126
- # print(f"Filename redirect: {filename}")
127
-
128
- filename_base = alternative_name if alternative_name else filename
129
- if not filename_base:
130
- return None, None
131
- elif os.path.exists(os.path.join(dir_, filename_base)):
132
- return os.path.join(dir_, filename_base), filename_base
133
-
134
- aria2_command = (
135
- f'aria2c --console-log-level=error --summary-interval=10 -c -x 16 '
136
- f'-k 1M -s 16 -d "{dir_}" -o "{filename_base}" "{redirect_url}"'
137
- )
138
- r_code = os.system(aria2_command) # noqa
139
-
140
- # if r_code != 0:
141
- # raise RuntimeError(f"Failed to download file: {filename_base}. Error code: {r_code}")
142
-
143
- output_path = os.path.join(dir_, filename_base)
144
- if not os.path.exists(output_path):
145
- return None, filename_base
146
-
147
- return output_path, filename_base
148
-
149
-
150
- def civ_api_down(url, dir_, civitai_api_key, civ_filename):
151
- """
152
- This method is susceptible to being blocked because it generates a lot of temp redirect links with aria2c.
153
- If an API key limit is reached, generating a new API key and using it can fix the issue.
154
- """
155
- output_path = None
156
-
157
- url_dl = url + f"?token={civitai_api_key}"
158
- if not civ_filename:
159
- aria2_command = f'aria2c -c -x 1 -s 1 -d "{dir_}" "{url_dl}"'
160
- os.system(aria2_command)
161
- else:
162
- output_path = os.path.join(dir_, civ_filename)
163
- if not os.path.exists(output_path):
164
- aria2_command = (
165
- f'aria2c --console-log-level=error --summary-interval=10 -c -x 16 '
166
- f'-k 1M -s 16 -d "{dir_}" -o "{civ_filename}" "{url_dl}"'
167
- )
168
- os.system(aria2_command)
169
-
170
- return output_path
171
-
172
-
173
- def drive_down(url, dir_):
174
- import gdown
175
-
176
- output_path = None
177
-
178
- drive_id, _ = gdown.parse_url.parse_url(url, warning=False)
179
- dir_files = os.listdir(dir_)
180
-
181
- for dfile in dir_files:
182
- if drive_id in dfile:
183
- output_path = os.path.join(dir_, dfile)
184
- break
185
-
186
- if not output_path:
187
- original_path = gdown.download(url, f"{dir_}/", fuzzy=True)
188
-
189
- dir_name, base_name = os.path.split(original_path)
190
- name, ext = base_name.rsplit(".", 1)
191
- new_name = f"{name}_{drive_id}.{ext}"
192
- output_path = os.path.join(dir_name, new_name)
193
-
194
- os.rename(original_path, output_path)
195
-
196
- return output_path
197
-
198
-
199
- def hf_down(url, dir_, hf_token, romanize):
200
- url = url.replace("?download=true", "")
201
- # url = urllib.parse.quote(url, safe=':/') # fix encoding
202
-
203
- filename = unidecode(url.split('/')[-1]) if romanize else url.split('/')[-1]
204
- output_path = os.path.join(dir_, filename)
205
-
206
- if os.path.exists(output_path):
207
- return output_path
208
-
209
- if "/blob/" in url:
210
- url = url.replace("/blob/", "/resolve/")
211
-
212
- if hf_token:
213
- user_header = f'"Authorization: Bearer {hf_token}"'
214
- os.system(f"aria2c --console-log-level=error --summary-interval=10 --header={user_header} -c -x 16 -k 1M -s 16 {url} -d {dir_} -o {filename}")
215
- else:
216
- os.system(f"aria2c --optimize-concurrent-downloads --console-log-level=error --summary-interval=10 -c -x 16 -k 1M -s 16 {url} -d {dir_} -o {filename}")
217
-
218
- return output_path
219
-
220
-
221
- def download_things(directory, url, hf_token="", civitai_api_key="", romanize=False):
222
- url = url.strip()
223
- downloaded_file_path = None
224
-
225
- if "drive.google.com" in url:
226
- downloaded_file_path = drive_down(url, directory)
227
- elif "huggingface.co" in url:
228
- downloaded_file_path = hf_down(url, directory, hf_token, romanize)
229
- elif "civitai.com" in url:
230
- if not civitai_api_key:
231
- msg = "You need an API key to download Civitai models."
232
- print(f"\033[91m{msg}\033[0m")
233
- gr.Warning(msg)
234
- return None
235
-
236
- url, civ_filename, civ_page = get_civit_params(url)
237
- if civ_page and not IS_ZERO_GPU:
238
- print(f"\033[92mCivitai model: {civ_filename} [page: {civ_page}]\033[0m")
239
-
240
- downloaded_file_path, civ_filename = civ_redirect_down(url, directory, civitai_api_key, romanize, civ_filename)
241
-
242
- if not downloaded_file_path:
243
- msg = (
244
- "Download failed.\n"
245
- "If this is due to an API limit, generating a new API key may resolve the issue.\n"
246
- "Attempting to download using the old method..."
247
- )
248
- print(msg)
249
- gr.Warning(msg)
250
- downloaded_file_path = civ_api_down(url, directory, civitai_api_key, civ_filename)
251
- else:
252
- os.system(f"aria2c --console-log-level=error --summary-interval=10 -c -x 16 -k 1M -s 16 -d {directory} {url}")
253
-
254
- return downloaded_file_path
255
-
256
-
257
- def get_model_list(directory_path):
258
- model_list = []
259
- valid_extensions = {'.ckpt', '.pt', '.pth', '.safetensors', '.bin'}
260
-
261
- for filename in os.listdir(directory_path):
262
- if os.path.splitext(filename)[1] in valid_extensions:
263
- # name_without_extension = os.path.splitext(filename)[0]
264
- file_path = os.path.join(directory_path, filename)
265
- # model_list.append((name_without_extension, file_path))
266
- model_list.append(file_path)
267
- print('\033[34mFILE: ' + file_path + '\033[0m')
268
- return model_list
269
-
270
-
271
- def extract_parameters(input_string):
272
- parameters = {}
273
- input_string = input_string.replace("\n", "")
274
-
275
- if "Negative prompt:" not in input_string:
276
- if "Steps:" in input_string:
277
- input_string = input_string.replace("Steps:", "Negative prompt: Steps:")
278
- else:
279
- msg = "Generation data is invalid."
280
- gr.Warning(msg)
281
- print(msg)
282
- parameters["prompt"] = input_string
283
- return parameters
284
-
285
- parm = input_string.split("Negative prompt:")
286
- parameters["prompt"] = parm[0].strip()
287
- if "Steps:" not in parm[1]:
288
- parameters["neg_prompt"] = parm[1].strip()
289
- return parameters
290
- parm = parm[1].split("Steps:")
291
- parameters["neg_prompt"] = parm[0].strip()
292
- input_string = "Steps:" + parm[1]
293
-
294
- # Extracting Steps
295
- steps_match = re.search(r'Steps: (\d+)', input_string)
296
- if steps_match:
297
- parameters['Steps'] = int(steps_match.group(1))
298
-
299
- # Extracting Size
300
- size_match = re.search(r'Size: (\d+x\d+)', input_string)
301
- if size_match:
302
- parameters['Size'] = size_match.group(1)
303
- width, height = map(int, parameters['Size'].split('x'))
304
- parameters['width'] = width
305
- parameters['height'] = height
306
-
307
- # Extracting other parameters
308
- other_parameters = re.findall(r'([^,:]+): (.*?)(?=, [^,:]+:|$)', input_string)
309
- for param in other_parameters:
310
- parameters[param[0].strip()] = param[1].strip('"')
311
-
312
- return parameters
313
-
314
-
315
- def get_my_lora(link_url, romanize):
316
- l_name = ""
317
- for url in [url.strip() for url in link_url.split(',')]:
318
- if not os.path.exists(f"./loras/{url.split('/')[-1]}"):
319
- l_name = download_things(DIRECTORY_LORAS, url, HF_TOKEN, CIVITAI_API_KEY, romanize)
320
- new_lora_model_list = get_model_list(DIRECTORY_LORAS)
321
- new_lora_model_list.insert(0, "None")
322
- new_lora_model_list = new_lora_model_list + DIFFUSERS_FORMAT_LORAS
323
- msg_lora = "Downloaded"
324
- if l_name:
325
- msg_lora += f": <b>{l_name}</b>"
326
- print(msg_lora)
327
-
328
- return gr.update(
329
- choices=new_lora_model_list
330
- ), gr.update(
331
- choices=new_lora_model_list
332
- ), gr.update(
333
- choices=new_lora_model_list
334
- ), gr.update(
335
- choices=new_lora_model_list
336
- ), gr.update(
337
- choices=new_lora_model_list
338
- ), gr.update(
339
- choices=new_lora_model_list
340
- ), gr.update(
341
- choices=new_lora_model_list
342
- ), gr.update(
343
- value=msg_lora
344
- )
345
-
346
-
347
- def info_html(json_data, title, subtitle):
348
- return f"""
349
- <div style='padding: 0; border-radius: 10px;'>
350
- <p style='margin: 0; font-weight: bold;'>{title}</p>
351
- <details>
352
- <summary>Details</summary>
353
- <p style='margin: 0; font-weight: bold;'>{subtitle}</p>
354
- </details>
355
- </div>
356
- """
357
-
358
-
359
- def get_model_type(repo_id: str):
360
- api = HfApi(token=os.environ.get("HF_TOKEN")) # if use private or gated model
361
- default = "SD 1.5"
362
- try:
363
- if os.path.exists(repo_id):
364
- tag, _, _, _ = checkpoint_model_type(repo_id)
365
- return DIFFUSECRAFT_CHECKPOINT_NAME[tag]
366
- else:
367
- model = api.model_info(repo_id=repo_id, timeout=5.0)
368
- tags = model.tags
369
- for tag in tags:
370
- if tag in MODEL_TYPE_CLASS.keys():
371
- return MODEL_TYPE_CLASS.get(tag, default)
372
-
373
- except Exception:
374
- return default
375
- return default
376
-
377
-
378
- def restart_space(repo_id: str, factory_reboot: bool):
379
- api = HfApi(token=os.environ.get("HF_TOKEN"))
380
- try:
381
- runtime = api.get_space_runtime(repo_id=repo_id)
382
- if runtime.stage == "RUNNING":
383
- api.restart_space(repo_id=repo_id, factory_reboot=factory_reboot)
384
- print(f"Restarting space: {repo_id}")
385
- else:
386
- print(f"Space {repo_id} is in stage: {runtime.stage}")
387
- except Exception as e:
388
- print(e)
389
-
390
-
391
- def extract_exif_data(image):
392
- if image is None:
393
- return ""
394
-
395
- try:
396
- metadata_keys = ['parameters', 'metadata', 'prompt', 'Comment']
397
-
398
- for key in metadata_keys:
399
- if key in image.info:
400
- return image.info[key]
401
-
402
- return str(image.info)
403
-
404
- except Exception as e:
405
- return f"Error extracting metadata: {str(e)}"
406
-
407
-
408
- def create_mask_now(img, invert):
409
- import numpy as np
410
- import time
411
-
412
- time.sleep(0.5)
413
-
414
- transparent_image = img["layers"][0]
415
-
416
- # Extract the alpha channel
417
- alpha_channel = np.array(transparent_image)[:, :, 3]
418
-
419
- # Create a binary mask by thresholding the alpha channel
420
- binary_mask = alpha_channel > 1
421
-
422
- if invert:
423
- print("Invert")
424
- # Invert the binary mask so that the drawn shape is white and the rest is black
425
- binary_mask = np.invert(binary_mask)
426
-
427
- # Convert the binary mask to a 3-channel RGB mask
428
- rgb_mask = np.stack((binary_mask,) * 3, axis=-1)
429
-
430
- # Convert the mask to uint8
431
- rgb_mask = rgb_mask.astype(np.uint8) * 255
432
-
433
- return img["background"], rgb_mask
434
-
435
-
436
- def download_diffuser_repo(repo_name: str, model_type: str, revision: str = "main", token=True):
437
-
438
- variant = None
439
- if token is True and not os.environ.get("HF_TOKEN"):
440
- token = None
441
-
442
- if model_type == "SDXL":
443
- info = model_info_data(
444
- repo_name,
445
- token=token,
446
- revision=revision,
447
- timeout=5.0,
448
- )
449
-
450
- filenames = {sibling.rfilename for sibling in info.siblings}
451
- model_filenames, variant_filenames = variant_compatible_siblings(
452
- filenames, variant="fp16"
453
- )
454
-
455
- if len(variant_filenames):
456
- variant = "fp16"
457
-
458
- if model_type == "FLUX":
459
- cached_folder = snapshot_download(
460
- repo_id=repo_name,
461
- allow_patterns="transformer/*"
462
- )
463
- else:
464
- cached_folder = DiffusionPipeline.download(
465
- pretrained_model_name=repo_name,
466
- force_download=False,
467
- token=token,
468
- revision=revision,
469
- # mirror="https://hf-mirror.com",
470
- variant=variant,
471
- use_safetensors=True,
472
- trust_remote_code=False,
473
- timeout=5.0,
474
- )
475
-
476
- if isinstance(cached_folder, PosixPath):
477
- cached_folder = cached_folder.as_posix()
478
-
479
- # Task model
480
- # from huggingface_hub import hf_hub_download
481
- # hf_hub_download(
482
- # task_model,
483
- # filename="diffusion_pytorch_model.safetensors", # fix fp16 variant
484
- # )
485
-
486
- return cached_folder
487
-
488
-
489
- def get_folder_size_gb(folder_path):
490
- result = subprocess.run(["du", "-s", folder_path], capture_output=True, text=True)
491
-
492
- total_size_kb = int(result.stdout.split()[0])
493
- total_size_gb = total_size_kb / (1024 ** 2)
494
-
495
- return total_size_gb
496
-
497
-
498
- def get_used_storage_gb(path_storage=STORAGE_ROOT):
499
- try:
500
- used_gb = get_folder_size_gb(path_storage)
501
- print(f"Used Storage: {used_gb:.2f} GB")
502
- except Exception as e:
503
- used_gb = 999
504
- print(f"Error while retrieving the used storage: {e}.")
505
-
506
- return used_gb
507
-
508
-
509
- def delete_model(removal_candidate):
510
- print(f"Removing: {removal_candidate}")
511
-
512
- if os.path.exists(removal_candidate):
513
- os.remove(removal_candidate)
514
- else:
515
- diffusers_model = f"{CACHE_HF}{DIRECTORY_MODELS}--{removal_candidate.replace('/', '--')}"
516
- if os.path.isdir(diffusers_model):
517
- shutil.rmtree(diffusers_model)
518
-
519
-
520
- def clear_hf_cache():
521
- """
522
- Clears the entire Hugging Face cache at ~/.cache/huggingface.
523
- Hugging Face will re-download models as needed later.
524
- """
525
- try:
526
- if os.path.exists(CACHE_HF_ROOT):
527
- shutil.rmtree(CACHE_HF_ROOT, ignore_errors=True)
528
- print(f"Hugging Face cache cleared: {CACHE_HF_ROOT}")
529
- else:
530
- print(f"No Hugging Face cache found at: {CACHE_HF_ROOT}")
531
- except Exception as e:
532
- print(f"Error clearing Hugging Face cache: {e}")
533
-
534
-
535
- def progress_step_bar(step, total):
536
- # Calculate the percentage for the progress bar width
537
- percentage = min(100, ((step / total) * 100))
538
-
539
- return f"""
540
- <div style="position: relative; width: 100%; background-color: gray; border-radius: 5px; overflow: hidden;">
541
- <div style="width: {percentage}%; height: 17px; background-color: #800080; transition: width 0.5s;"></div>
542
- <div style="position: absolute; width: 100%; text-align: center; color: white; top: 0; line-height: 19px; font-size: 13px;">
543
- {int(percentage)}%
544
- </div>
545
- </div>
546
- """
547
-
548
-
549
- def html_template_message(msg):
550
- return f"""
551
- <div style="position: relative; width: 100%; background-color: gray; border-radius: 5px; overflow: hidden;">
552
- <div style="width: 0%; height: 17px; background-color: #800080; transition: width 0.5s;"></div>
553
- <div style="position: absolute; width: 100%; text-align: center; color: white; top: 0; line-height: 19px; font-size: 14px; font-weight: bold; text-shadow: 1px 1px 2px black;">
554
- {msg}
555
- </div>
556
- </div>
557
- """
558
-
559
-
560
- def escape_html(text):
561
- """Escapes HTML special characters in the input text."""
562
- return text.replace("<", "&lt;").replace(">", "&gt;").replace("\n", "<br>")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+ import gradio as gr
4
+ from constants import (
5
+ DIFFUSERS_FORMAT_LORAS,
6
+ CIVITAI_API_KEY,
7
+ HF_TOKEN,
8
+ MODEL_TYPE_CLASS,
9
+ DIRECTORY_LORAS,
10
+ DIRECTORY_MODELS,
11
+ DIFFUSECRAFT_CHECKPOINT_NAME,
12
+ CACHE_HF_ROOT,
13
+ CACHE_HF,
14
+ STORAGE_ROOT,
15
+ )
16
+ from huggingface_hub import HfApi, get_hf_file_metadata, snapshot_download
17
+ from diffusers import DiffusionPipeline
18
+ from huggingface_hub import model_info as model_info_data
19
+ from diffusers.pipelines.pipeline_loading_utils import variant_compatible_siblings
20
+ from stablepy.diffusers_vanilla.utils import checkpoint_model_type
21
+ from pathlib import PosixPath
22
+ from unidecode import unidecode
23
+ import urllib.parse
24
+ import copy
25
+ import requests
26
+ from requests.adapters import HTTPAdapter
27
+ from urllib3.util import Retry
28
+ import shutil
29
+ import subprocess
30
+ import json
31
+ import html as _html
32
+
33
+ IS_ZERO_GPU = bool(os.getenv("SPACES_ZERO_GPU"))
34
+ USER_AGENT = 'Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:127.0) Gecko/20100101 Firefox/127.0'
35
+ MODEL_ARCH = {
36
+ 'stable-diffusion-xl-v1-base/lora': "Stable Diffusion XL (Illustrious, Pony, NoobAI)",
37
+ 'stable-diffusion-v1/lora': "Stable Diffusion 1.5",
38
+ 'flux-1-dev/lora': "Flux",
39
+ }
40
+
41
+
42
+ def read_safetensors_header_from_url(url: str):
43
+ """Read safetensors header from a remote Hugging Face file."""
44
+ meta = get_hf_file_metadata(url)
45
+
46
+ # Step 1: first 8 bytes → header length
47
+ resp = requests.get(meta.location, headers={"Range": "bytes=0-7"})
48
+ resp.raise_for_status()
49
+ header_len = int.from_bytes(resp.content, "little")
50
+
51
+ # Step 2: fetch full header JSON
52
+ end = 8 + header_len - 1
53
+ resp = requests.get(meta.location, headers={"Range": f"bytes=8-{end}"})
54
+ resp.raise_for_status()
55
+ header_json = resp.content.decode("utf-8")
56
+
57
+ return json.loads(header_json)
58
+
59
+
60
+ def read_safetensors_header_from_file(path: str):
61
+ """Read safetensors header from a local file."""
62
+ with open(path, "rb") as f:
63
+ # Step 1: first 8 bytes → header length
64
+ header_len = int.from_bytes(f.read(8), "little")
65
+
66
+ # Step 2: read header JSON
67
+ header_json = f.read(header_len).decode("utf-8")
68
+
69
+ return json.loads(header_json)
70
+
71
+
72
+ class LoraHeaderInformation:
73
+ """
74
+ Encapsulates parsed info from a LoRA JSON header and provides
75
+ a compact HTML summary via .to_html().
76
+ """
77
+
78
+ def __init__(self, json_data):
79
+ self.original_json = copy.deepcopy(json_data or {})
80
+
81
+ # Check if text encoder was trained
82
+ # guard for json_data being a mapping
83
+ try:
84
+ self.text_encoder_trained = any("text_model" in ln for ln in json_data)
85
+ except Exception:
86
+ self.text_encoder_trained = False
87
+
88
+ # Metadata (may be None)
89
+ metadata = (json_data or {}).get("__metadata__", None)
90
+ self.metadata = metadata
91
+
92
+ # Default values
93
+ self.architecture = "undefined"
94
+ self.prediction_type = "undefined"
95
+ self.base_model = "undefined"
96
+ self.author = "undefined"
97
+ self.title = "undefined"
98
+ self.common_tags_list = []
99
+
100
+ if metadata:
101
+ self.architecture = MODEL_ARCH.get(
102
+ metadata.get('modelspec.architecture', None),
103
+ "undefined"
104
+ )
105
+
106
+ self.prediction_type = metadata.get('modelspec.prediction_type', "undefined")
107
+ self.base_model = metadata.get('ss_sd_model_name', "undefined")
108
+ self.author = metadata.get('modelspec.author', "undefined")
109
+ self.title = metadata.get('modelspec.title', "undefined")
110
+
111
+ base_model_hash = metadata.get('ss_new_sd_model_hash', None) # SHA256
112
+ # AUTOV1 ss_sd_model_hash
113
+ # https://civitai.com/api/v1/model-versions/by-hash/{base_model_hash} # Info
114
+ if base_model_hash:
115
+ self.base_model += f" hash={base_model_hash}"
116
+
117
+ # Extract tags
118
+ try:
119
+ tags = metadata.get('ss_tag_frequency') if "ss_tag_frequency" in metadata else metadata.get('ss_datasets', "")
120
+ tags = json.loads(tags) if tags else ""
121
+
122
+ if isinstance(tags, list):
123
+ tags = tags[0].get("tag_frequency", {})
124
+
125
+ if tags:
126
+ self.common_tags_list = list(tags[list(tags.keys())[0]].keys())
127
+ except Exception:
128
+ self.common_tags_list = []
129
+
130
+ def to_dict(self):
131
+ """Return a plain dict summary of parsed fields."""
132
+ return {
133
+ "architecture": self.architecture,
134
+ "prediction_type": self.prediction_type,
135
+ "base_model": self.base_model,
136
+ "author": self.author,
137
+ "title": self.title,
138
+ "text_encoder_trained": bool(self.text_encoder_trained),
139
+ "common_tags": self.common_tags_list,
140
+ }
141
+
142
+ def to_html(self, limit_tags=20):
143
+ """
144
+ Return a compact HTML snippet (string) showing the parsed info
145
+ in a small font. Values are HTML-escaped.
146
+ """
147
+ # helper to escape
148
+ esc = _html.escape
149
+
150
+ rows = [
151
+ ("Title", esc(str(self.title))),
152
+ ("Author", esc(str(self.author))),
153
+ ("Architecture", esc(str(self.architecture))),
154
+ ("Base model", esc(str(self.base_model))),
155
+ ("Prediction type", esc(str(self.prediction_type))),
156
+ ("Text encoder trained", esc(str(self.text_encoder_trained))),
157
+ ("Reference tags", esc(str(", ".join(self.common_tags_list[:limit_tags])))),
158
+ ]
159
+
160
+ # small, compact table with inline styling (small font)
161
+ html_rows = "".join(
162
+ f"<tr><th style='text-align:left;padding:2px 6px;white-space:nowrap'>{k}</th>"
163
+ f"<td style='padding:2px 6px'>{v}</td></tr>"
164
+ for k, v in rows
165
+ )
166
+
167
+ html_snippet = (
168
+ "<div style='font-family:system-ui, -apple-system, \"Segoe UI\", Roboto, "
169
+ "Helvetica, Arial, \"Noto Sans\", sans-serif; font-size:12px; line-height:1.2; "
170
+ "'>"
171
+ f"<table style='border-collapse:collapse; font-size:12px;'>"
172
+ f"{html_rows}"
173
+ "</table>"
174
+ "</div>"
175
+ )
176
+
177
+ return html_snippet
178
+
179
+
180
+ def request_json_data(url):
181
+ model_version_id = url.split('/')[-1]
182
+ if "?modelVersionId=" in model_version_id:
183
+ match = re.search(r'modelVersionId=(\d+)', url)
184
+ model_version_id = match.group(1)
185
+
186
+ endpoint_url = f"https://civitai.com/api/v1/model-versions/{model_version_id}"
187
+
188
+ params = {}
189
+ headers = {'User-Agent': USER_AGENT, 'content-type': 'application/json'}
190
+ session = requests.Session()
191
+ retries = Retry(total=5, backoff_factor=1, status_forcelist=[500, 502, 503, 504])
192
+ session.mount("https://", HTTPAdapter(max_retries=retries))
193
+
194
+ try:
195
+ result = session.get(endpoint_url, params=params, headers=headers, stream=True, timeout=(3.0, 15))
196
+ result.raise_for_status()
197
+ json_data = result.json()
198
+ return json_data if json_data else None
199
+ except Exception as e:
200
+ print(f"Error: {e}")
201
+ return None
202
+
203
+
204
+ class ModelInformation:
205
+ def __init__(self, json_data):
206
+ self.model_version_id = json_data.get("id", "")
207
+ self.model_id = json_data.get("modelId", "")
208
+ self.download_url = json_data.get("downloadUrl", "")
209
+ self.model_url = f"https://civitai.com/models/{self.model_id}?modelVersionId={self.model_version_id}"
210
+ self.filename_url = next(
211
+ (v.get("name", "") for v in json_data.get("files", []) if str(self.model_version_id) in v.get("downloadUrl", "") and v.get("type", "Model") == "Model"), ""
212
+ )
213
+ self.filename_url = self.filename_url if self.filename_url else ""
214
+ self.description = json_data.get("description", "")
215
+ if self.description is None:
216
+ self.description = ""
217
+ self.model_name = json_data.get("model", {}).get("name", "")
218
+ self.model_type = json_data.get("model", {}).get("type", "")
219
+ self.nsfw = json_data.get("model", {}).get("nsfw", False)
220
+ self.poi = json_data.get("model", {}).get("poi", False)
221
+ self.images = [img.get("url", "") for img in json_data.get("images", [])]
222
+ self.example_prompt = json_data.get("trainedWords", [""])[0] if json_data.get("trainedWords") else ""
223
+ self.original_json = copy.deepcopy(json_data)
224
+
225
+
226
+ def get_civit_params(url):
227
+ try:
228
+ json_data = request_json_data(url)
229
+ mdc = ModelInformation(json_data)
230
+ if mdc.download_url and mdc.filename_url:
231
+ return mdc.download_url, mdc.filename_url, mdc.model_url
232
+ else:
233
+ ValueError("Invalid Civitai model URL")
234
+ except Exception as e:
235
+ print(f"Error retrieving Civitai metadata: {e} — fallback to direct download")
236
+ return url, None, None
237
+
238
+
239
+ def civ_redirect_down(url, dir_, civitai_api_key, romanize, alternative_name):
240
+ filename_base = filename = None
241
+
242
+ if alternative_name:
243
+ output_path = os.path.join(dir_, alternative_name)
244
+ if os.path.exists(output_path):
245
+ return output_path, alternative_name
246
+
247
+ # Follow the redirect to get the actual download URL
248
+ curl_command = (
249
+ f'curl -L -sI --connect-timeout 5 --max-time 5 '
250
+ f'-H "Content-Type: application/json" '
251
+ f'-H "Authorization: Bearer {civitai_api_key}" "{url}"'
252
+ )
253
+
254
+ headers = os.popen(curl_command).read()
255
+
256
+ # Look for the redirected "Location" URL
257
+ location_match = re.search(r'location: (.+)', headers, re.IGNORECASE)
258
+
259
+ if location_match:
260
+ redirect_url = location_match.group(1).strip()
261
+
262
+ # Extract the filename from the redirect URL's "Content-Disposition"
263
+ filename_match = re.search(r'filename%3D%22(.+?)%22', redirect_url)
264
+ if filename_match:
265
+ encoded_filename = filename_match.group(1)
266
+ # Decode the URL-encoded filename
267
+ decoded_filename = urllib.parse.unquote(encoded_filename)
268
+
269
+ filename = unidecode(decoded_filename) if romanize else decoded_filename
270
+ # print(f"Filename redirect: {filename}")
271
+
272
+ filename_base = alternative_name if alternative_name else filename
273
+ if not filename_base:
274
+ return None, None
275
+ elif os.path.exists(os.path.join(dir_, filename_base)):
276
+ return os.path.join(dir_, filename_base), filename_base
277
+
278
+ aria2_command = (
279
+ f'aria2c --console-log-level=error --summary-interval=10 -c -x 16 '
280
+ f'-k 1M -s 16 -d "{dir_}" -o "{filename_base}" "{redirect_url}"'
281
+ )
282
+ r_code = os.system(aria2_command) # noqa
283
+
284
+ # if r_code != 0:
285
+ # raise RuntimeError(f"Failed to download file: {filename_base}. Error code: {r_code}")
286
+
287
+ output_path = os.path.join(dir_, filename_base)
288
+ if not os.path.exists(output_path):
289
+ return None, filename_base
290
+
291
+ return output_path, filename_base
292
+
293
+
294
+ def civ_api_down(url, dir_, civitai_api_key, civ_filename):
295
+ """
296
+ This method is susceptible to being blocked because it generates a lot of temp redirect links with aria2c.
297
+ If an API key limit is reached, generating a new API key and using it can fix the issue.
298
+ """
299
+ output_path = None
300
+
301
+ url_dl = url + f"?token={civitai_api_key}"
302
+ if not civ_filename:
303
+ aria2_command = f'aria2c -c -x 1 -s 1 -d "{dir_}" "{url_dl}"'
304
+ os.system(aria2_command)
305
+ else:
306
+ output_path = os.path.join(dir_, civ_filename)
307
+ if not os.path.exists(output_path):
308
+ aria2_command = (
309
+ f'aria2c --console-log-level=error --summary-interval=10 -c -x 16 '
310
+ f'-k 1M -s 16 -d "{dir_}" -o "{civ_filename}" "{url_dl}"'
311
+ )
312
+ os.system(aria2_command)
313
+
314
+ return output_path
315
+
316
+
317
+ def drive_down(url, dir_):
318
+ import gdown
319
+
320
+ output_path = None
321
+
322
+ drive_id, _ = gdown.parse_url.parse_url(url, warning=False)
323
+ dir_files = os.listdir(dir_)
324
+
325
+ for dfile in dir_files:
326
+ if drive_id in dfile:
327
+ output_path = os.path.join(dir_, dfile)
328
+ break
329
+
330
+ if not output_path:
331
+ original_path = gdown.download(url, f"{dir_}/", fuzzy=True)
332
+
333
+ dir_name, base_name = os.path.split(original_path)
334
+ name, ext = base_name.rsplit(".", 1)
335
+ new_name = f"{name}_{drive_id}.{ext}"
336
+ output_path = os.path.join(dir_name, new_name)
337
+
338
+ os.rename(original_path, output_path)
339
+
340
+ return output_path
341
+
342
+
343
+ def hf_down(url, dir_, hf_token, romanize):
344
+ url = url.replace("?download=true", "")
345
+ # url = urllib.parse.quote(url, safe=':/') # fix encoding
346
+
347
+ filename = unidecode(url.split('/')[-1]) if romanize else url.split('/')[-1]
348
+ output_path = os.path.join(dir_, filename)
349
+
350
+ if os.path.exists(output_path):
351
+ return output_path
352
+
353
+ if "/blob/" in url:
354
+ url = url.replace("/blob/", "/resolve/")
355
+
356
+ if hf_token:
357
+ user_header = f'"Authorization: Bearer {hf_token}"'
358
+ os.system(f"aria2c --console-log-level=error --summary-interval=10 --header={user_header} -c -x 16 -k 1M -s 16 {url} -d {dir_} -o {filename}")
359
+ else:
360
+ os.system(f"aria2c --optimize-concurrent-downloads --console-log-level=error --summary-interval=10 -c -x 16 -k 1M -s 16 {url} -d {dir_} -o {filename}")
361
+
362
+ return output_path
363
+
364
+
365
+ def download_things(directory, url, hf_token="", civitai_api_key="", romanize=False):
366
+ url = url.strip()
367
+ downloaded_file_path = None
368
+
369
+ if "drive.google.com" in url:
370
+ downloaded_file_path = drive_down(url, directory)
371
+ elif "huggingface.co" in url:
372
+ downloaded_file_path = hf_down(url, directory, hf_token, romanize)
373
+ elif "civitai.com" in url:
374
+ if not civitai_api_key:
375
+ msg = "You need an API key to download Civitai models."
376
+ print(f"\033[91m{msg}\033[0m")
377
+ gr.Warning(msg)
378
+ return None
379
+
380
+ url, civ_filename, civ_page = get_civit_params(url)
381
+ if civ_page and not IS_ZERO_GPU:
382
+ print(f"\033[92mCivitai model: {civ_filename} [page: {civ_page}]\033[0m")
383
+
384
+ downloaded_file_path, civ_filename = civ_redirect_down(url, directory, civitai_api_key, romanize, civ_filename)
385
+
386
+ if not downloaded_file_path:
387
+ msg = (
388
+ "Download failed.\n"
389
+ "If this is due to an API limit, generating a new API key may resolve the issue.\n"
390
+ "Attempting to download using the old method..."
391
+ )
392
+ print(msg)
393
+ gr.Warning(msg)
394
+ downloaded_file_path = civ_api_down(url, directory, civitai_api_key, civ_filename)
395
+ else:
396
+ os.system(f"aria2c --console-log-level=error --summary-interval=10 -c -x 16 -k 1M -s 16 -d {directory} {url}")
397
+
398
+ return downloaded_file_path
399
+
400
+
401
+ def get_model_list(directory_path):
402
+ model_list = []
403
+ valid_extensions = {'.ckpt', '.pt', '.pth', '.safetensors', '.bin'}
404
+
405
+ for filename in os.listdir(directory_path):
406
+ if os.path.splitext(filename)[1] in valid_extensions:
407
+ # name_without_extension = os.path.splitext(filename)[0]
408
+ file_path = os.path.join(directory_path, filename)
409
+ # model_list.append((name_without_extension, file_path))
410
+ model_list.append(file_path)
411
+ print('\033[34mFILE: ' + file_path + '\033[0m')
412
+ return model_list
413
+
414
+
415
+ def extract_parameters(input_string):
416
+ parameters = {}
417
+ input_string = input_string.replace("\n", "")
418
+
419
+ if "Negative prompt:" not in input_string:
420
+ if "Steps:" in input_string:
421
+ input_string = input_string.replace("Steps:", "Negative prompt: Steps:")
422
+ else:
423
+ msg = "Generation data is invalid."
424
+ gr.Warning(msg)
425
+ print(msg)
426
+ parameters["prompt"] = input_string
427
+ return parameters
428
+
429
+ parm = input_string.split("Negative prompt:")
430
+ parameters["prompt"] = parm[0].strip()
431
+ if "Steps:" not in parm[1]:
432
+ parameters["neg_prompt"] = parm[1].strip()
433
+ return parameters
434
+ parm = parm[1].split("Steps:")
435
+ parameters["neg_prompt"] = parm[0].strip()
436
+ input_string = "Steps:" + parm[1]
437
+
438
+ # Extracting Steps
439
+ steps_match = re.search(r'Steps: (\d+)', input_string)
440
+ if steps_match:
441
+ parameters['Steps'] = int(steps_match.group(1))
442
+
443
+ # Extracting Size
444
+ size_match = re.search(r'Size: (\d+x\d+)', input_string)
445
+ if size_match:
446
+ parameters['Size'] = size_match.group(1)
447
+ width, height = map(int, parameters['Size'].split('x'))
448
+ parameters['width'] = width
449
+ parameters['height'] = height
450
+
451
+ # Extracting other parameters
452
+ other_parameters = re.findall(r'([^,:]+): (.*?)(?=, [^,:]+:|$)', input_string)
453
+ for param in other_parameters:
454
+ parameters[param[0].strip()] = param[1].strip('"')
455
+
456
+ return parameters
457
+
458
+
459
+ def get_my_lora(link_url, romanize):
460
+ l_name = ""
461
+ for url in [url.strip() for url in link_url.split(',')]:
462
+ if not os.path.exists(f"./loras/{url.split('/')[-1]}"):
463
+ l_name = download_things(DIRECTORY_LORAS, url, HF_TOKEN, CIVITAI_API_KEY, romanize)
464
+ new_lora_model_list = get_model_list(DIRECTORY_LORAS)
465
+ new_lora_model_list.insert(0, "None")
466
+ new_lora_model_list = new_lora_model_list + DIFFUSERS_FORMAT_LORAS
467
+ msg_lora = "Downloaded"
468
+ if l_name:
469
+ msg_lora += f": <b>{l_name}</b>"
470
+ print(msg_lora)
471
+
472
+ try:
473
+ # Works with non-Civitai loras.
474
+ json_data = read_safetensors_header_from_file(l_name)
475
+ metadata_lora = LoraHeaderInformation(json_data)
476
+ msg_lora += "<br>" + metadata_lora.to_html()
477
+ except Exception:
478
+ pass
479
+
480
+ return gr.update(
481
+ choices=new_lora_model_list
482
+ ), gr.update(
483
+ choices=new_lora_model_list
484
+ ), gr.update(
485
+ choices=new_lora_model_list
486
+ ), gr.update(
487
+ choices=new_lora_model_list
488
+ ), gr.update(
489
+ choices=new_lora_model_list
490
+ ), gr.update(
491
+ choices=new_lora_model_list
492
+ ), gr.update(
493
+ choices=new_lora_model_list
494
+ ), gr.update(
495
+ value=msg_lora
496
+ )
497
+
498
+
499
+ def info_html(json_data, title, subtitle):
500
+ return f"""
501
+ <div style='padding: 0; border-radius: 10px;'>
502
+ <p style='margin: 0; font-weight: bold;'>{title}</p>
503
+ <details>
504
+ <summary>Details</summary>
505
+ <p style='margin: 0; font-weight: bold;'>{subtitle}</p>
506
+ </details>
507
+ </div>
508
+ """
509
+
510
+
511
+ def get_model_type(repo_id: str):
512
+ api = HfApi(token=os.environ.get("HF_TOKEN")) # if use private or gated model
513
+ default = "SD 1.5"
514
+ try:
515
+ if os.path.exists(repo_id):
516
+ tag, _, _, _ = checkpoint_model_type(repo_id)
517
+ return DIFFUSECRAFT_CHECKPOINT_NAME[tag]
518
+ else:
519
+ model = api.model_info(repo_id=repo_id, timeout=5.0)
520
+ tags = model.tags
521
+ for tag in tags:
522
+ if tag in MODEL_TYPE_CLASS.keys():
523
+ return MODEL_TYPE_CLASS.get(tag, default)
524
+
525
+ except Exception:
526
+ return default
527
+ return default
528
+
529
+
530
+ def restart_space(repo_id: str, factory_reboot: bool):
531
+ api = HfApi(token=os.environ.get("HF_TOKEN"))
532
+ try:
533
+ runtime = api.get_space_runtime(repo_id=repo_id)
534
+ if runtime.stage == "RUNNING":
535
+ api.restart_space(repo_id=repo_id, factory_reboot=factory_reboot)
536
+ print(f"Restarting space: {repo_id}")
537
+ else:
538
+ print(f"Space {repo_id} is in stage: {runtime.stage}")
539
+ except Exception as e:
540
+ print(e)
541
+
542
+
543
+ def extract_exif_data(image):
544
+ if image is None:
545
+ return ""
546
+
547
+ try:
548
+ metadata_keys = ['parameters', 'metadata', 'prompt', 'Comment']
549
+
550
+ for key in metadata_keys:
551
+ if key in image.info:
552
+ return image.info[key]
553
+
554
+ return str(image.info)
555
+
556
+ except Exception as e:
557
+ return f"Error extracting metadata: {str(e)}"
558
+
559
+
560
+ def create_mask_now(img, invert):
561
+ import numpy as np
562
+ import time
563
+
564
+ time.sleep(0.5)
565
+
566
+ transparent_image = img["layers"][0]
567
+
568
+ # Extract the alpha channel
569
+ alpha_channel = np.array(transparent_image)[:, :, 3]
570
+
571
+ # Create a binary mask by thresholding the alpha channel
572
+ binary_mask = alpha_channel > 1
573
+
574
+ if invert:
575
+ print("Invert")
576
+ # Invert the binary mask so that the drawn shape is white and the rest is black
577
+ binary_mask = np.invert(binary_mask)
578
+
579
+ # Convert the binary mask to a 3-channel RGB mask
580
+ rgb_mask = np.stack((binary_mask,) * 3, axis=-1)
581
+
582
+ # Convert the mask to uint8
583
+ rgb_mask = rgb_mask.astype(np.uint8) * 255
584
+
585
+ return img["background"], rgb_mask
586
+
587
+
588
+ def download_diffuser_repo(repo_name: str, model_type: str, revision: str = "main", token=True):
589
+
590
+ variant = None
591
+ if token is True and not os.environ.get("HF_TOKEN"):
592
+ token = None
593
+
594
+ if model_type == "SDXL":
595
+ info = model_info_data(
596
+ repo_name,
597
+ token=token,
598
+ revision=revision,
599
+ timeout=5.0,
600
+ )
601
+
602
+ filenames = {sibling.rfilename for sibling in info.siblings}
603
+ model_filenames, variant_filenames = variant_compatible_siblings(
604
+ filenames, variant="fp16"
605
+ )
606
+
607
+ if len(variant_filenames):
608
+ variant = "fp16"
609
+
610
+ if model_type == "FLUX":
611
+ cached_folder = snapshot_download(
612
+ repo_id=repo_name,
613
+ allow_patterns="transformer/*"
614
+ )
615
+ else:
616
+ cached_folder = DiffusionPipeline.download(
617
+ pretrained_model_name=repo_name,
618
+ force_download=False,
619
+ token=token,
620
+ revision=revision,
621
+ # mirror="https://hf-mirror.com",
622
+ variant=variant,
623
+ use_safetensors=True,
624
+ trust_remote_code=False,
625
+ timeout=5.0,
626
+ )
627
+
628
+ if isinstance(cached_folder, PosixPath):
629
+ cached_folder = cached_folder.as_posix()
630
+
631
+ # Task model
632
+ # from huggingface_hub import hf_hub_download
633
+ # hf_hub_download(
634
+ # task_model,
635
+ # filename="diffusion_pytorch_model.safetensors", # fix fp16 variant
636
+ # )
637
+
638
+ return cached_folder
639
+
640
+
641
+ def get_folder_size_gb(folder_path):
642
+ result = subprocess.run(["du", "-s", folder_path], capture_output=True, text=True)
643
+
644
+ total_size_kb = int(result.stdout.split()[0])
645
+ total_size_gb = total_size_kb / (1024 ** 2)
646
+
647
+ return total_size_gb
648
+
649
+
650
+ def get_used_storage_gb(path_storage=STORAGE_ROOT):
651
+ try:
652
+ used_gb = get_folder_size_gb(path_storage)
653
+ print(f"Used Storage: {used_gb:.2f} GB")
654
+ except Exception as e:
655
+ used_gb = 999
656
+ print(f"Error while retrieving the used storage: {e}.")
657
+
658
+ return used_gb
659
+
660
+
661
+ def delete_model(removal_candidate):
662
+ print(f"Removing: {removal_candidate}")
663
+
664
+ if os.path.exists(removal_candidate):
665
+ os.remove(removal_candidate)
666
+ else:
667
+ diffusers_model = f"{CACHE_HF}{DIRECTORY_MODELS}--{removal_candidate.replace('/', '--')}"
668
+ if os.path.isdir(diffusers_model):
669
+ shutil.rmtree(diffusers_model)
670
+
671
+
672
+ def clear_hf_cache():
673
+ """
674
+ Clears the entire Hugging Face cache at ~/.cache/huggingface.
675
+ Hugging Face will re-download models as needed later.
676
+ """
677
+ try:
678
+ if os.path.exists(CACHE_HF):
679
+ shutil.rmtree(CACHE_HF, ignore_errors=True)
680
+ print(f"Hugging Face cache cleared: {CACHE_HF}")
681
+ else:
682
+ print(f"No Hugging Face cache found at: {CACHE_HF}")
683
+ except Exception as e:
684
+ print(f"Error clearing Hugging Face cache: {e}")
685
+
686
+
687
+ def progress_step_bar(step, total):
688
+ # Calculate the percentage for the progress bar width
689
+ percentage = min(100, ((step / total) * 100))
690
+
691
+ return f"""
692
+ <div style="position: relative; width: 100%; background-color: gray; border-radius: 5px; overflow: hidden;">
693
+ <div style="width: {percentage}%; height: 17px; background-color: #800080; transition: width 0.5s;"></div>
694
+ <div style="position: absolute; width: 100%; text-align: center; color: white; top: 0; line-height: 19px; font-size: 13px;">
695
+ {int(percentage)}%
696
+ </div>
697
+ </div>
698
+ """
699
+
700
+
701
+ def html_template_message(msg):
702
+ return f"""
703
+ <div style="position: relative; width: 100%; background-color: gray; border-radius: 5px; overflow: hidden;">
704
+ <div style="width: 0%; height: 17px; background-color: #800080; transition: width 0.5s;"></div>
705
+ <div style="position: absolute; width: 100%; text-align: center; color: white; top: 0; line-height: 19px; font-size: 14px; font-weight: bold; text-shadow: 1px 1px 2px black;">
706
+ {msg}
707
+ </div>
708
+ </div>
709
+ """
710
+
711
+
712
+ def escape_html(text):
713
+ """Escapes HTML special characters in the input text."""
714
+ return text.replace("<", "&lt;").replace(">", "&gt;").replace("\n", "<br>")