Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| from huggingface_hub import HfApi, hf_hub_url | |
| from huggingface_hub.hf_api import RepoFile | |
| import os | |
| from pathlib import Path | |
| import gc | |
| import requests | |
| from requests.adapters import HTTPAdapter | |
| from urllib3.util import Retry | |
| import urllib | |
| from utils import (get_token, set_token, is_repo_exists, get_user_agent, get_download_file, | |
| list_uniq, list_sub, duplicate_hf_repo, HF_SUBFOLDER_NAME, get_state, set_state) | |
| import re | |
| from PIL import Image | |
| import json | |
| import pandas as pd | |
| import tempfile | |
| import hashlib | |
# Process-wide scratch directory for downloads and generated info files.
# Created once at import time; individual files are unlinked after upload,
# but the directory itself is never removed (left to OS temp cleanup).
TEMP_DIR = tempfile.mkdtemp()
def parse_urls(s):
    """Return every http(s) URL found in *s* as a list (empty on any error)."""
    pattern = r"https?://[\w/:%#\$&\?\(\)~\.=\+\-]+"
    try:
        return [match for match in re.findall(pattern, s)]
    except Exception:
        return []
def parse_repos(s):
    """Extract 'owner/name' repo identifiers from *s*, ignoring anything that is part of a URL."""
    repo_re = r'[^\w_\-\.]?([\w_\-\.]+/[\w_\-\.]+)[^\w_\-\.]?'
    url_re = "https?://[\\w/:%#\\$&\\?\\(\\)~\\.=\\+\\-]+"
    try:
        # Strip URLs first so path segments are not mistaken for repo ids.
        remainder = re.sub(url_re, "", s)
        return [repo for repo in re.findall(repo_re, remainder)]
    except Exception:
        return []
def to_urls(l: list[str]):
    """Concatenate a list of URL strings into one newline-separated block."""
    joined = "\n".join(l)
    return joined
def uniq_urls(s):
    """Parse URLs and repo ids out of *s*, de-duplicate, and return them one per line."""
    found = parse_urls(s)
    found += parse_repos(s)
    return to_urls(list_uniq(found))
def upload_safetensors_to_repo(filename, repo_id, repo_type, is_private, progress=gr.Progress(track_tqdm=True)):
    """Upload a local file into the root of an HF repo, creating the repo if needed.

    Returns the resolved hf_hub_url of the uploaded file, or None on failure.
    The local file is always deleted afterwards (success or failure).
    """
    output_filename = Path(filename).name
    hf_token = get_token()
    api = HfApi(token=hf_token)
    url = None  # pre-bind so the final return is always defined
    try:
        if not is_repo_exists(repo_id, repo_type):
            api.create_repo(repo_id=repo_id, repo_type=repo_type, token=hf_token, private=is_private)
        # FIX: progress text previously showed the literal "(unknown)" instead of the file.
        progress(0, desc=f"Start uploading... {filename} to {repo_id}")
        api.upload_file(path_or_fileobj=filename, path_in_repo=output_filename, repo_type=repo_type, revision="main", token=hf_token, repo_id=repo_id)
        progress(1, desc="Uploaded.")
        url = hf_hub_url(repo_id=repo_id, repo_type=repo_type, filename=output_filename)
    except Exception as e:
        print(f"Error: Failed to upload to {repo_id}. {e}")
        gr.Warning(f"Error: Failed to upload to {repo_id}. {e}")
        return None
    finally:
        # Local copy is no longer needed whether or not the upload succeeded.
        if Path(filename).exists(): Path(filename).unlink()
    return url
def get_repo_hashes(repo_id: str, repo_type: str="model"):
    """Return the SHA-256 digests of every LFS-tracked file in the repo.

    Returns [] when the repo does not exist or listing fails.
    """
    hf_token = get_token()
    api = HfApi(token=hf_token)
    hashes = []
    try:
        if not api.repo_exists(repo_id=repo_id, repo_type=repo_type, token=hf_token):
            return hashes
        for f in api.list_repo_tree(repo_id=repo_id, repo_type=repo_type, token=hf_token):
            # Only LFS-tracked RepoFile entries carry a sha256 in their lfs dict.
            if not isinstance(f, RepoFile) or f.lfs is None:
                continue
            sha = f.lfs.get("sha256", None)
            if sha is not None:
                hashes.append(sha)
    except Exception as e:
        print(e)
    # FIX: return moved out of `finally` — returning from a finally block
    # silently swallows even KeyboardInterrupt/SystemExit.
    return hashes
def get_civitai_sha256(dl_url: str, api_key=""):
    """Resolve the SHA-256 of the Civitai file behind a download URL.

    The download URL's query string (type/format/size/fp) selects which file
    variant of the model version is wanted. Returns the lowercase hex digest,
    or None when the URL is not a Civitai download link or lookup fails.
    """
    def is_invalid_file(qs: dict, meta: dict, k: str):
        # Disqualify a file entry when the query string constrains key *k*
        # and the entry's (non-missing) metadata value differs.
        return k in qs.keys() and qs[k][0] != meta.get(k, None) and meta.get(k, None) is not None
    if "https://civitai.com/api/download/models/" not in dl_url: return None
    user_agent = get_user_agent()
    headers = {'User-Agent': user_agent, 'content-type': 'application/json'}
    # FIX: f'Bearer {{{api_key}}}' rendered literal braces ("Bearer {key}"),
    # which is not a valid RFC 6750 bearer credential.
    if api_key: headers['Authorization'] = f'Bearer {api_key}'
    base_url = 'https://civitai.com/api/v1/model-versions/'
    session = requests.Session()
    retries = Retry(total=5, backoff_factor=1, status_forcelist=[500, 502, 503, 504])
    session.mount("https://", HTTPAdapter(max_retries=retries))
    m = re.match(r'https://civitai.com/api/download/models/(\d+)\??(.+)?', dl_url)
    if m is None: return None
    url = base_url + m.group(1)
    qs = urllib.parse.parse_qs(m.group(2))
    # Civitai defaults to the primary "Model" file when no type is requested.
    if "type" not in qs.keys(): qs["type"] = ["Model"]
    try:
        r = session.get(url, headers=headers, stream=True, timeout=(5.0, 15))
        if not r.ok: return None
        data = dict(r.json())  # renamed from `json` to avoid shadowing the module
        if "files" not in data.keys() or not isinstance(data["files"], list): return None
        for d in data["files"]:
            if is_invalid_file(qs, d, "type") or is_invalid_file(qs, d, "format") or is_invalid_file(qs, d, "size") or is_invalid_file(qs, d, "fp"): continue
            return d["hashes"]["SHA256"].lower()
        return None
    except Exception as e:
        print(e)
        return None
def is_same_file(filename: str, cmp_sha256: str, cmp_size: int):
    """Return True iff the local file matches both *cmp_size* and *cmp_sha256*.

    When *cmp_sha256* is falsy the hash check degenerates to comparing against
    the empty string, so only a falsy expected hash can match.
    """
    digest = ""
    if cmp_sha256:
        hasher = hashlib.sha256()
        with open(filename, "rb") as fh:
            # Stream in fixed-size chunks so large files don't load into memory.
            while chunk := fh.read(4096):
                hasher.update(chunk)
        digest = hasher.hexdigest()
    return os.path.getsize(filename) == cmp_size and digest == cmp_sha256
def get_safe_filename(filename, repo_id, repo_type):
    """Return a collision-free local filename for upload into *repo_id*.

    While a file with the same basename exists in the repo, compare size/sha256:
    identical content keeps the name (break), different content appends _1, _2, …
    The local file is renamed on disk when the name changes. On any API error the
    best name found so far is returned.
    """
    hf_token = get_token()
    api = HfApi(token=hf_token)
    new_filename = filename
    try:
        i = 1
        while api.file_exists(repo_id=repo_id, filename=Path(new_filename).name, repo_type=repo_type, token=hf_token):
            infos = api.get_paths_info(repo_id=repo_id, paths=[Path(new_filename).name], repo_type=repo_type, token=hf_token)
            if infos and len(infos) == 1:
                repo_fs = infos[0].size
                repo_sha256 = infos[0].lfs.sha256 if infos[0].lfs is not None else ""
                # Identical content already present: keep the current name.
                if is_same_file(filename, repo_sha256, repo_fs): break
            new_filename = str(Path(Path(filename).parent, f"{Path(filename).stem}_{i}{Path(filename).suffix}"))
            i += 1
        if filename != new_filename:
            # FIX: message previously printed the literal "(unknown)" and had grammar/typo issues.
            print(f"{Path(filename).name} already exists but file content is different. renaming to {Path(new_filename).name}.")
            Path(filename).rename(new_filename)
    except Exception as e:
        print(f"Error occurred when renaming {filename}. {e}")
    # FIX: return moved out of `finally` so fatal interrupts are not swallowed.
    return new_filename
def download_file(dl_url, civitai_key, progress=gr.Progress(track_tqdm=True)):
    """Download *dl_url* into the process temp dir and return the local path."""
    progress(0, desc=f"Start downloading... {dl_url}")
    return get_download_file(TEMP_DIR, dl_url, civitai_key)
def save_civitai_info(dl_url, filename, civitai_key="", progress=gr.Progress(track_tqdm=True)):
    """Fetch Civitai metadata for *dl_url* and persist it next to *filename*'s stem.

    Writes <stem>.json and <stem>.html into TEMP_DIR and returns
    (json_path, html_path, image_path); ("", "", "") when the metadata fetch
    or a file write fails.
    """
    info, html_str, image_path = get_civitai_json(dl_url, True, filename, civitai_key)
    if not info: return "", "", ""
    json_path = str(Path(TEMP_DIR, Path(filename).stem + ".json"))
    html_path = str(Path(TEMP_DIR, Path(filename).stem + ".html"))
    try:
        # FIX: pin utf-8 (was platform default) to match the HTML writer below.
        with open(json_path, 'w', encoding="utf-8") as f:
            json.dump(info, f, indent=2)
        with open(html_path, mode='w', encoding="utf-8") as f:
            f.write(html_str)
        return json_path, html_path, image_path
    except Exception as e:
        print(f"Error: Failed to save info file {json_path}, {html_path} {e}")
        return "", "", ""
def upload_info_to_repo(dl_url, filename, repo_id, repo_type, is_private, civitai_key="", progress=gr.Progress(track_tqdm=True)):
    """Fetch Civitai metadata for *dl_url* and upload the .json/.html/preview image to *repo_id*.

    Best-effort: failures are logged and surfaced as a gr.Warning; nothing is raised.
    Each uploaded local file is deleted after its upload.
    """
    def upload_file(api, filename, repo_id, repo_type, hf_token):
        # Upload one local file into the repo root, then drop the local copy.
        if not Path(filename).exists(): return
        api.upload_file(path_or_fileobj=filename, path_in_repo=Path(filename).name, repo_type=repo_type, revision="main", token=hf_token, repo_id=repo_id)
        Path(filename).unlink()
    hf_token = get_token()
    api = HfApi(token=hf_token)
    try:
        if not is_repo_exists(repo_id, repo_type):
            api.create_repo(repo_id=repo_id, repo_type=repo_type, token=hf_token, private=is_private)
        # FIX: progress texts previously showed the literal "(unknown)" instead of the file.
        progress(0, desc=f"Downloading info... {filename}")
        json_path, html_path, image_path = save_civitai_info(dl_url, filename, civitai_key)
        progress(0, desc=f"Start uploading info... {filename} to {repo_id}")
        if not json_path: return
        upload_file(api, json_path, repo_id, repo_type, hf_token)
        if html_path: upload_file(api, html_path, repo_id, repo_type, hf_token)
        if image_path: upload_file(api, image_path, repo_id, repo_type, hf_token)
        progress(1, desc="Info uploaded.")
        return
    except Exception as e:
        print(f"Error: Failed to upload info to {repo_id}. {e}")
        gr.Warning(f"Error: Failed to upload info to {repo_id}. {e}")
        return
def download_civitai(dl_url, civitai_key, hf_token, urls,
                     newrepo_id, repo_type="model", is_private=True, is_info=False, is_rename=True, progress=gr.Progress(track_tqdm=True)):
    """Gradio handler: download every Civitai URL in *dl_url* and mirror it into *newrepo_id*.

    Also duplicates any bare 'owner/name' repos mentioned in *dl_url*. Returns three
    gr.update payloads: (uploaded-URL choices, markdown report, remaining-URLs textbox).
    On error the remaining-URLs textbox is made visible so the user can retry.
    """
    if hf_token: set_token(hf_token)
    else: set_token(os.getenv("HF_TOKEN", False)) # default huggingface write token
    if not civitai_key: civitai_key = os.environ.get("CIVITAI_API_KEY") # default Civitai API key
    if not newrepo_id: newrepo_id = os.environ.get("HF_REPO") # default repo to upload
    if not get_token() or not civitai_key: raise gr.Error("HF write token and Civitai API key is required.")
    if not urls: urls = []
    dl_urls = parse_urls(dl_url)
    remain_urls = dl_urls.copy()
    # SHA-256 set of files already in the target repo, used to skip duplicates.
    hashes = set(get_repo_hashes(newrepo_id, repo_type))
    try:
        md = f'### Your repo: [{newrepo_id}]({"https://huggingface.co/datasets/" if repo_type == "dataset" else "https://huggingface.co/"}{newrepo_id})\n'
        for u in dl_urls:
            # get_civitai_sha256 returns None on lookup failure; None is never in
            # `hashes`, so unresolvable URLs fall through to a normal download.
            if get_civitai_sha256(u, civitai_key) in hashes:
                print(f"{u} is already exitsts. skipping.")
                remain_urls.remove(u)
                md += f"- Skipped [{str(u)}]({str(u)})\n"
                continue
            file = download_file(u, civitai_key)
            if not Path(file).exists() or not Path(file).is_file(): continue
            if is_rename: file = get_safe_filename(file, newrepo_id, repo_type)
            url = upload_safetensors_to_repo(file, newrepo_id, repo_type, is_private)
            if url:
                if is_info: upload_info_to_repo(u, file, newrepo_id, repo_type, is_private, civitai_key)
                urls.append(url)
                # Only successfully uploaded URLs leave the retry list.
                remain_urls.remove(u)
                md += f"- Uploaded [{str(u)}]({str(u)})\n"
        # Bare 'owner/name' mentions are duplicated as whole repos.
        dp_repos = parse_repos(dl_url)
        for r in dp_repos:
            url = duplicate_hf_repo(r, newrepo_id, "model", repo_type, is_private, HF_SUBFOLDER_NAME[1])
            if url: urls.append(url)
        return gr.update(value=urls, choices=urls), gr.update(value=md), gr.update(value="\n".join(remain_urls), visible=False)
    except Exception as e:
        # NOTE(review): if the very first statement in the try (building `md`)
        # raised, `md` would be unbound here — confirm acceptable.
        gr.Info(f"Error occured: {e}")
        return gr.update(value=urls, choices=urls), gr.update(value=md), gr.update(value="\n".join(remain_urls), visible=True)
    finally:
        gc.collect()
# Option lists for the search UI; string values mirror the Civitai /models API
# query parameters (types, sort, period) and file metadata fields.
CIVITAI_TYPE = ["Checkpoint", "TextualInversion", "Hypernetwork", "AestheticGradient", "LORA", "LoCon", "DoRA",
                "Controlnet", "Upscaler", "MotionModule", "VAE", "Poses", "Wildcards", "Workflows", "Other"]
CIVITAI_FILETYPE = ["Model", "VAE", "Config", "Training Data"]
# Base-model filter choices (applied client-side against modelVersions.baseModel).
CIVITAI_BASEMODEL = ["Pony", "Illustrious", "SDXL 1.0", "SD 1.5", "Flux.1 D", "Flux.1 S", "SD 3.5"]
#CIVITAI_SORT = ["Highest Rated", "Most Downloaded", "Newest"]
CIVITAI_SORT = ["Highest Rated", "Most Downloaded", "Most Liked", "Most Discussed", "Most Collected", "Most Buzz", "Newest"]
CIVITAI_PERIOD = ["AllTime", "Year", "Month", "Week", "Day"]
def search_on_civitai(query: str, types: list[str], allow_model: list[str] = [], limit: int = 100,
                      sort: str = "Highest Rated", period: str = "AllTime", tag: str = "", user: str = "", page: int = 1,
                      filetype: list[str] = [], api_key: str = "", progress=gr.Progress(track_tqdm=True)):
    """Query the Civitai /models API and flatten results into a list of item dicts.

    page == 0 crawls every page by following metadata.nextPage; otherwise only
    the requested page is fetched. `allow_model` filters by baseModel, `filetype`
    by file type. Each item carries name/creator/tags/md/img_url/dl_url.
    Returns None when nothing matched or every request failed.
    """
    user_agent = get_user_agent()
    headers = {'User-Agent': user_agent, 'content-type': 'application/json'}
    # FIX: f'Bearer {{{api_key}}}' sent literal braces around the key, which is
    # not a valid bearer credential.
    if api_key: headers['Authorization'] = f'Bearer {api_key}'
    base_url = 'https://civitai.com/api/v1/models'
    params = {'sort': sort, 'period': period, 'limit': int(limit), 'nsfw': 'true'}
    if len(types) != 0: params["types"] = types
    if query: params["query"] = query
    if tag: params["tag"] = tag
    if user: params["username"] = user
    if page != 0: params["page"] = int(page)
    session = requests.Session()
    retries = Retry(total=5, backoff_factor=1, status_forcelist=[500, 502, 503, 504])
    session.mount("https://", HTTPAdapter(max_retries=retries))
    rs = []
    try:
        if page == 0:
            # Crawl mode: start at page 1 and follow metadata.nextPage links.
            progress(0, desc="Searching page 1...")
            print("Searching page 1...")
            r = session.get(base_url, params=params | {'page': 1}, headers=headers, stream=True, timeout=(7.0, 30))
            rs.append(r)
            if r.ok:
                data = r.json()  # renamed from `json` to avoid shadowing the module
                next_url = data['metadata']['nextPage'] if 'metadata' in data and 'nextPage' in data['metadata'] else None
                i = 2
                while next_url is not None:
                    progress(0, desc=f"Searching page {i}...")
                    print(f"Searching page {i}...")
                    r = session.get(next_url, headers=headers, stream=True, timeout=(7.0, 30))
                    rs.append(r)
                    if r.ok:
                        data = r.json()
                        next_url = data['metadata']['nextPage'] if 'metadata' in data and 'nextPage' in data['metadata'] else None
                    else: next_url = None
                    i += 1
        else:
            progress(0, desc="Searching page 1...")
            print("Searching page 1...")
            r = session.get(base_url, params=params, headers=headers, stream=True, timeout=(7.0, 30))
            rs.append(r)
    except requests.exceptions.ConnectTimeout:
        print("Request timed out.")
    except Exception as e:
        print(e)
    items = []
    for r in rs:
        if not r.ok: continue
        data = r.json()
        if 'items' not in data: continue
        for j in data['items']:
            for model in j['modelVersions']:
                item = {}
                if len(allow_model) != 0 and model['baseModel'] not in set(allow_model): continue
                item['name'] = j['name']
                item['creator'] = j['creator']['username'] if 'creator' in j.keys() and 'username' in j['creator'].keys() else ""
                item['tags'] = j['tags'] if 'tags' in j.keys() else []
                item['model_name'] = model['name'] if 'name' in model.keys() else ""
                item['base_model'] = model['baseModel'] if 'baseModel' in model.keys() else ""
                item['description'] = model['description'] if 'description' in model.keys() else ""
                item['md'] = ""
                if 'images' in model.keys() and len(model["images"]) != 0:
                    item['img_url'] = model["images"][0]["url"]
                    item['md'] += f'<img src="{model["images"][0]["url"]}#float" alt="thumbnail" width="150" height="240"><br>'
                else: item['img_url'] = "/home/user/app/null.png"
                item['md'] += f'''Model URL: [https://civitai.com/models/{j["id"]}](https://civitai.com/models/{j["id"]})<br>Model Name: {item["name"]}<br>
Creator: {item["creator"]}<br>Tags: {", ".join(item["tags"])}<br>Base Model: {item["base_model"]}<br>Description: {item["description"]}'''
                if 'files' in model.keys():
                    # One result item per downloadable file variant.
                    for f in model['files']:
                        i = item.copy()
                        i['dl_url'] = f['downloadUrl']
                        if len(filetype) != 0 and f['type'] not in set(filetype): continue
                        items.append(i)
                else:
                    item['dl_url'] = model['downloadUrl']
                    items.append(item)
    return items if len(items) > 0 else None
def search_civitai(query, types, base_model=[], sort=CIVITAI_SORT[0], period=CIVITAI_PERIOD[0], tag="", user="", limit=100, page=1,
                   filetype=[], api_key="", gallery=[], state={}, progress=gr.Progress(track_tqdm=True)):
    """Gradio handler: run a Civitai search and populate the results UI.

    Resets the cached choices/gallery/results in *state*, queries via
    search_on_civitai, then returns 8 outputs: checkbox-group update, markdown
    update, two passthrough updates, gallery update, second checkbox update,
    a results-count string, and the (mutated) state dict.
    NOTE(review): the mutable defaults (base_model/filetype/gallery/state) are
    never mutated here except *state* via set_state — confirm gradio always
    supplies a fresh state dict.
    """
    civitai_last_results = {}
    # Clear previous search artifacts before the new query.
    set_state(state, "civitai_last_choices", [("", "")])
    set_state(state, "civitai_last_gallery", [])
    set_state(state, "civitai_last_results", civitai_last_results)
    results_info = "No item found."
    items = search_on_civitai(query, types, base_model, int(limit), sort, period, tag, user, int(page), filetype, api_key)
    if not items: return gr.update(choices=[("", "")], value=[], visible=True),\
        gr.update(value="", visible=False), gr.update(), gr.update(), gr.update(), gr.update(), results_info, state
    choices = []
    gallery = []
    for item in items:
        base_model_name = "Pony🐴" if item['base_model'] == "Pony" else item['base_model']
        name = f"{item['name']} (for {base_model_name} / By: {item['creator']})"
        value = item['dl_url']
        choices.append((name, value))
        gallery.append((item['img_url'], name))
        # Index full item dicts by download URL for later selection lookups.
        civitai_last_results[value] = item
    if len(choices) >= 1: results_info = f"{int(len(choices))} items found."
    else: choices = [("", "")]
    md = ""
    set_state(state, "civitai_last_choices", choices)
    set_state(state, "civitai_last_gallery", gallery)
    set_state(state, "civitai_last_results", civitai_last_results)
    return gr.update(choices=choices, value=[], visible=True), gr.update(value=md, visible=True),\
        gr.update(), gr.update(), gr.update(value=gallery), gr.update(choices=choices, value=[]), results_info, state
def get_civitai_json(dl_url: str, is_html: bool=False, image_baseurl: str="", api_key=""):
    """Fetch the model-version JSON for a Civitai download URL, plus the model's
    HTML page and first preview image (converted to PNG in TEMP_DIR).

    Returns (json_dict, html_text, image_path) on success. On failure returns
    `default`: ("", "", "") when is_html, else "".
    """
    if not image_baseurl: image_baseurl = dl_url
    default = ("", "", "") if is_html else ""
    if "https://civitai.com/api/download/models/" not in dl_url: return default
    user_agent = get_user_agent()
    headers = {'User-Agent': user_agent, 'content-type': 'application/json'}
    # FIX: f'Bearer {{{api_key}}}' sent literal braces around the key ("Bearer {key}").
    if api_key: headers['Authorization'] = f'Bearer {api_key}'
    base_url = 'https://civitai.com/api/v1/model-versions/'
    params = {}
    session = requests.Session()
    retries = Retry(total=5, backoff_factor=1, status_forcelist=[500, 502, 503, 504])
    session.mount("https://", HTTPAdapter(max_retries=retries))
    # The numeric model-version id is the path segment after /download/models/.
    model_id = re.sub(r'https://civitai.com/api/download/models/(\d+)(?:.+)?', r'\1', dl_url)
    url = base_url + model_id
    try:
        r = session.get(url, params=params, headers=headers, stream=True, timeout=(5.0, 15))
        if not r.ok: return default
        data = dict(r.json()).copy()  # renamed from `json` to avoid shadowing the module
        html = ""
        image = ""
        if "modelId" in data.keys():
            url = f"https://civitai.com/models/{data['modelId']}"
            r = session.get(url, params=params, headers=headers, stream=True, timeout=(5.0, 15))
            if not r.ok: return data, html, image
            html = r.text
        if 'images' in data.keys() and len(data["images"]) != 0:
            url = data["images"][0]["url"]
            r = session.get(url, params=params, headers=headers, stream=True, timeout=(5.0, 15))
            if not r.ok: return data, html, image
            image_temp = str(Path(TEMP_DIR, "image" + Path(url.split("/")[-1]).suffix))
            # Normalize whatever format Civitai serves to a PNG named after the model file.
            image = str(Path(TEMP_DIR, Path(image_baseurl.split("/")[-1]).stem + ".png"))
            with open(image_temp, 'wb') as f:
                f.write(r.content)
            Image.open(image_temp).convert('RGBA').save(image)
        return data, html, image
    except Exception as e:
        print(e)
        return default
def get_civitai_tag():
    """Fetch up to 200 Civitai tags and return [""] + tag names sorted by model count (desc).

    Returns [""] on any failure so UI dropdowns always have an empty choice.
    """
    default = [""]
    user_agent = get_user_agent()
    headers = {'User-Agent': user_agent, 'content-type': 'application/json'}
    base_url = 'https://civitai.com/api/v1/tags'
    params = {'limit': 200}
    session = requests.Session()
    retries = Retry(total=5, backoff_factor=1, status_forcelist=[500, 502, 503, 504])
    session.mount("https://", HTTPAdapter(max_retries=retries))
    try:
        r = session.get(base_url, params=params, headers=headers, stream=True, timeout=(7.0, 15))
        if not r.ok: return default
        j = dict(r.json()).copy()
        if "items" not in j.keys(): return default
        items = [[str(item.get("name", "")), int(item.get("modelCount", 0))] for item in j["items"]]
        df = pd.DataFrame(items)
        # FIX: sort_values returns a NEW frame (inplace defaults to False); the
        # result was previously discarded, so tags were never actually sorted.
        df = df.sort_values(1, ascending=False)
        tags = [""] + [row[0] for row in df.values.tolist()]
        return tags
    except Exception as e:
        print(e)
        return default
def select_civitai_item(results: list[str], state: dict):
    """Show the markdown card for the most recently added selection in *results*."""
    empty_json = {}
    if len(results) == 0 or "http" not in "".join(results):
        return gr.update(value="", visible=True), gr.update(value=empty_json, visible=False), state
    cached = get_state(state, "civitai_last_results")
    previous = get_state(state, "civitai_last_selects")
    # Items present now but not in the previous selection = newly picked.
    newly_picked = list_sub(results, previous if previous else [])
    md = ""
    if cached and isinstance(cached, dict) and len(newly_picked) > 0:
        md = cached.get(newly_picked[-1]).get('md', "")
    set_state(state, "civitai_last_selects", results)
    return gr.update(value=md, visible=True), gr.update(value=empty_json, visible=False), state
def add_civitai_item(results: list[str], dl_url: str):
    """Append the chosen result URLs to the download textbox, de-duplicated."""
    if "http" not in "".join(results):
        return gr.update(value=dl_url)
    combined = dl_url if dl_url else ""
    for entry in results:
        if "http" not in entry:
            continue
        combined = f"{combined}\n{entry}" if combined else f"{entry}"
    return gr.update(value=uniq_urls(combined))
def select_civitai_all_item(button_name: str, state: dict):
    """Toggle between selecting all cached choices and clearing the selection.

    Returns updates for (toggle-button label, checkbox group).
    """
    # FIX: was `gr.Update(visible=True)` — gradio exposes `gr.update` (lowercase);
    # the capitalized name raised AttributeError on this guard branch.
    if button_name not in ["Select All", "Deselect All"]:
        return gr.update(value=button_name), gr.update(visible=True)
    civitai_last_choices = get_state(state, "civitai_last_choices")
    # "Select All" selects every non-empty value; "Deselect All" clears.
    selected = [t[1] for t in civitai_last_choices if t[1] != ""] if button_name == "Select All" else []
    new_button_name = "Select All" if button_name == "Deselect All" else "Deselect All"
    return gr.update(value=new_button_name), gr.update(value=selected, choices=civitai_last_choices)
def update_civitai_selection(evt: gr.SelectData, value: list[str], state: dict):
    """Merge a gallery click (evt.index) into the current checkbox selection."""
    try:
        cached_choices = get_state(state, "civitai_last_choices")
        picked = cached_choices[evt.index][1]
        merged = list_uniq([v for v in value if v != ""] + [picked])
        return gr.update(value=merged)
    except Exception:
        # Stale index or missing state: leave the selection untouched.
        return gr.update()
def update_civitai_checkbox(selected: list[str]):
    """Mirror *selected* into the paired checkbox group."""
    mirrored = selected
    return gr.update(value=mirrored)
def from_civitai_checkbox(selected: list[str]):
    """Mirror the checkbox group's *selected* values back to its counterpart."""
    mirrored = selected
    return gr.update(value=mirrored)