Spaces:
Runtime error
Runtime error
| import io | |
| import os | |
| import shutil | |
| import zipfile | |
| import gradio as gr | |
| import requests | |
| from huggingface_hub import create_repo, upload_folder, whoami | |
| from convert import convert_full_checkpoint | |
# Scratch locations for a single conversion run. The convert handlers
# delete MODELS_DIR and recreate HF_MODEL_DIR at the start of every request.
MODELS_DIR = "models/"
CKPT_FILE = f"{MODELS_DIR}model.ckpt"  # raw downloaded checkpoint
HF_MODEL_DIR = f"{MODELS_DIR}diffusers_model"  # converted diffusers folder
ZIP_FILE = f"{MODELS_DIR}model.zip"  # archive offered for download
def download_ckpt(url, out_path, timeout=60):
    """Stream the resource at ``url`` into the local file ``out_path``.

    The destination file is only opened once the server has answered with a
    successful status, so an HTTP error does not leave an empty or truncated
    file behind.

    Args:
        url: HTTP(S) URL of the checkpoint to download.
        out_path: Local destination path; overwritten if it already exists.
        timeout: Seconds ``requests`` waits for the connection and between
            received bytes. ``None`` disables the timeout.

    Raises:
        requests.HTTPError: If the server responds with an error status.
        requests.Timeout: If the server stalls for longer than ``timeout``.
    """
    with requests.get(url, stream=True, timeout=timeout) as response:
        response.raise_for_status()
        with open(out_path, "wb") as out_file:
            for chunk in response.iter_content(chunk_size=8192):
                # iter_content may yield empty keep-alive chunks; skip them.
                if chunk:
                    out_file.write(chunk)
def zip_model(model_path, zip_path):
    """Pack every file below ``model_path`` into an uncompressed zip at ``zip_path``.

    Member names are made relative to the parent of ``model_path``, so the
    directory's own name becomes the top-level folder inside the archive
    (e.g. ``diffusers_model/<file>``). Only files are written; empty
    directories do not appear in the archive.
    """
    archive_root = os.path.join(model_path, os.pardir)
    with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_STORED) as archive:
        for dirpath, _subdirs, filenames in os.walk(model_path):
            for name in filenames:
                full_path = os.path.join(dirpath, name)
                member_name = os.path.relpath(full_path, archive_root)
                archive.write(full_path, member_name)
def download_checkpoint_and_config(ckpt_url, config_url):
    """Validate the user-supplied URLs, then fetch the config and the checkpoint.

    Args:
        ckpt_url: HTTP(S) URL of the ``.ckpt`` file. Surrounding whitespace
            is ignored.
        config_url: HTTP(S) URL of the original YAML config, or an empty
            string to fall back to the bundled ``original_config.yaml``.

    Returns:
        A ``(CKPT_FILE, config_file)`` tuple. ``config_file`` is an in-memory
        ``io.BytesIO`` holding the YAML bytes for both the remote and the
        bundled config, so callers always get the same type and no OS file
        handle is left open.

    Raises:
        ValueError: If the checkpoint URL, or a non-empty config URL, does
            not start with ``http://`` or ``https://``.
        requests.HTTPError: If downloading the config or checkpoint fails.
    """
    ckpt_url = ckpt_url.strip()
    config_url = config_url.strip()
    url_schemes = ("http://", "https://")

    if not ckpt_url.startswith(url_schemes):
        raise ValueError("Invalid checkpoint URL")

    if config_url.startswith(url_schemes):
        response = requests.get(config_url, timeout=60)
        response.raise_for_status()
        config_file = io.BytesIO(response.content)
    elif config_url:
        raise ValueError("Invalid config URL")
    else:
        # Read the bundled default eagerly. The handle is closed here, and
        # this branch returns the same BytesIO type as the URL branch.
        with open("original_config.yaml", "rb") as default_config:
            config_file = io.BytesIO(default_config.read())

    download_ckpt(ckpt_url, CKPT_FILE)
    return CKPT_FILE, config_file
def convert_and_download(ckpt_url, config_url, scheduler_type, extract_ema):
    """Convert a remote ``.ckpt`` to diffusers format and return a zip of it.

    Wipes ``MODELS_DIR``, downloads the checkpoint (and optional config),
    converts it into ``HF_MODEL_DIR``, then zips that folder.

    Args:
        ckpt_url: HTTP(S) URL of the checkpoint.
        config_url: HTTP(S) URL of the YAML config, or "" for the bundled default.
        scheduler_type: Scheduler name forwarded to ``convert_full_checkpoint``.
        extract_ema: Radio value; the string "EMA" selects the EMA weights.

    Returns:
        Path of the produced zip archive (``ZIP_FILE``).

    Raises:
        ValueError: For malformed URLs (from the download helper).
    """
    shutil.rmtree(MODELS_DIR, ignore_errors=True)
    os.makedirs(HF_MODEL_DIR)
    ckpt_path, config_file = download_checkpoint_and_config(ckpt_url, config_url)
    # Close the config handle once conversion is done, even if it fails.
    with config_file:
        convert_full_checkpoint(
            ckpt_path,
            config_file,
            scheduler_type=scheduler_type,
            extract_ema=(extract_ema == "EMA"),
            output_path=HF_MODEL_DIR,
        )
    zip_model(HF_MODEL_DIR, ZIP_FILE)
    return ZIP_FILE
def convert_and_upload(
    ckpt_url, config_url, scheduler_type, extract_ema, token, model_name
):
    """Convert a remote ``.ckpt`` to diffusers format and push it to the Hub.

    The repository ``<username>/<model_name>`` is created if needed, and the
    converted folder is uploaded to it.

    Args:
        ckpt_url: HTTP(S) URL of the checkpoint.
        config_url: HTTP(S) URL of the YAML config, or "" for the bundled default.
        scheduler_type: Scheduler name forwarded to ``convert_full_checkpoint``.
        extract_ema: Radio value; the string "EMA" selects the EMA weights.
        token: Hugging Face access token (surrounding whitespace ignored).
        model_name: Repository name under the token owner's namespace.

    Returns:
        A Markdown string: either a success message linking to the repo, or
        "#### Error: ..." describing what went wrong. Errors are reported
        rather than raised because this is the UI boundary.
    """
    # Users often paste tokens and names with stray whitespace.
    token = (token or "").strip()
    model_name = (model_name or "").strip()
    if not token:
        return "#### Error: please paste a Hugging Face WRITE token."
    if not model_name:
        return "#### Error: please choose a model name."

    shutil.rmtree(MODELS_DIR, ignore_errors=True)
    os.makedirs(HF_MODEL_DIR)
    try:
        # Resolve the token owner first so a bad token fails before the
        # (potentially multi-GB) checkpoint download.
        username = whoami(token)["name"]
        repo_name = f"{username}/{model_name}"
        ckpt_path, config_file = download_checkpoint_and_config(ckpt_url, config_url)
        repo_url = create_repo(repo_name, token=token, exist_ok=True)
        with config_file:
            convert_full_checkpoint(
                ckpt_path,
                config_file,
                scheduler_type=scheduler_type,
                extract_ema=(extract_ema == "EMA"),
                output_path=HF_MODEL_DIR,
            )
        upload_folder(
            repo_id=repo_name,
            folder_path=HF_MODEL_DIR,
            token=token,
            commit_message="Upload diffusers weights",
        )
    except Exception as e:  # UI boundary: surface any failure to the user.
        return f"#### Error: {e}"
    return f"#### Success! Model uploaded to [{repo_url}]({repo_url})"
# Banner image shown above the page title (rendered with gr.HTML).
TITLE_IMAGE = """
<div
style="
display: block;
margin-left: auto;
margin-right: auto;
width: 50%;
"
>
<img src="https://github.com/huggingface/diffusers/raw/main/docs/source/imgs/diffusers_library.jpg" alt="Diffusers library"/>
</div>
"""
# Backward-compatible alias for the previously misspelt constant name.
TTILE_IMAGE = TITLE_IMAGE

# Page heading. It is passed to gr.HTML, which renders raw HTML rather than
# Markdown, so inline code uses <code> instead of backticks.
TITLE = """
<div
style="
display: inline-flex;
align-items: center;
text-align: center;
max-width: 1400px;
gap: 0.8rem;
font-size: 2.2rem;
"
>
<h1 style="font-weight: 900; margin-bottom: 10px; margin-top: 10px;">
Convert Stable Diffusion <code>.ckpt</code> files to Hugging Face Diffusers 🔥
</h1>
</div>
"""
with gr.Blocks() as interface:
    gr.HTML(TTILE_IMAGE)
    gr.HTML(TITLE)
    gr.Markdown("We will perform all of the checkpoint surgery for you, and create a clean diffusers model!")
    gr.Markdown("This converter will also remove any pickled code from third-party checkpoints.")

    # Step 1: source checkpoint and optional config.
    with gr.Row():
        with gr.Column(scale=50):
            # Backticks keep "<model>" from being parsed as an HTML tag by
            # the Markdown renderer (which would hide it).
            gr.Markdown("### 1. Paste a URL to your `<model>.ckpt` file")
            ckpt_url = gr.Textbox(
                max_lines=1,
                label="URL to <model>.ckpt",
                placeholder="https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned.ckpt",
            )
        with gr.Column(scale=50):
            gr.Markdown("### (Optional) paste a URL to your `<config>.yaml` file")
            config_url = gr.Textbox(
                max_lines=1,
                label="URL to <config>.yaml",
                placeholder="https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/v1-inference.yaml",
            )
            gr.Markdown(
                "*If you don't provide a config file, we'll try to use"
                " [v1-inference.yaml](https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-inference.yaml).*"
            )

    with gr.Accordion("Advanced Settings"):
        scheduler_type = gr.Dropdown(
            label="Choose a scheduler type (if not sure, keep the PNDM default)",
            choices=["PNDM", "K-LMS", "Euler", "EulerAncestral", "DDIM"],
            value="PNDM",
        )
        extract_ema = gr.Radio(
            label=(
                "EMA weights usually yield higher quality images for inference."
                " Non-EMA weights are usually better to continue fine-tuning."
            ),
            choices=["EMA", "Non-EMA"],
            value="EMA",
            interactive=True,
        )

    # Step 2: choose the output mode. type="index" makes the handler receive
    # the position of the selected choice (0 = download, 1 = Hub).
    gr.Markdown("### 2. Choose what to do with the converted model")
    model_choice = gr.Radio(
        show_label=False,
        choices=[
            "Download the model as an archive",
            "Host the model on the Hugging Face Hub",
            # "Submit a PR with the model for an existing Hub repository",
        ],
        type="index",
        value="Download the model as an archive",
        interactive=True,
    )

    download_panel = gr.Column(visible=True)
    upload_panel = gr.Column(visible=False)
    # pr_panel = gr.Column(visible=False)

    def _toggle_panels(choice_index):
        """Show only the panel that matches the selected radio index."""
        return (
            gr.update(visible=(choice_index == 0)),
            gr.update(visible=(choice_index == 1)),
        )

    # One handler updates both panels in a single event.
    model_choice.change(
        fn=_toggle_panels,
        inputs=model_choice,
        outputs=[download_panel, upload_panel],
    )
    # model_choice.change(
    #     fn=lambda i: gr.update(visible=(i == 2)),
    #     inputs=model_choice,
    #     outputs=pr_panel,
    # )

    # Step 3a: convert and download as a zip archive.
    with download_panel:
        gr.Markdown("### 3. Convert and download")
        down_btn = gr.Button("Convert")
        output_file = gr.File(
            label="Download the converted model",
            type="binary",
            interactive=False,
            visible=True,
        )
        down_btn.click(
            fn=convert_and_download,
            inputs=[ckpt_url, config_url, scheduler_type, extract_ema],
            outputs=output_file,
        )

    # Step 3b: convert and upload to the user's Hub namespace.
    with upload_panel:
        gr.Markdown("### 3. Convert and host on the Hub")
        gr.Markdown(
            "This will create a new repository if it doesn't exist yet, and upload the model to the Hugging Face Hub.\n\n"
            "Paste a WRITE token from [https://huggingface.co/settings/tokens](https://huggingface.co/settings/tokens)"
            " and make up a model name."
        )
        # Mask the secret token instead of echoing it on screen.
        up_token = gr.Textbox(
            max_lines=1,
            label="Hugging Face token",
            type="password",
        )
        up_model_name = gr.Textbox(
            max_lines=1,
            label="Hub model name (e.g. `artistic-diffusion-v1`)",
            placeholder="my-awesome-model",
        )
        upload_btn = gr.Button("Convert and upload")
        # gr.Box was removed in Gradio 4; gr.Group exists in both 3.x and 4.x.
        with gr.Group():
            output_text = gr.Markdown()
        upload_btn.click(
            fn=convert_and_upload,
            inputs=[
                ckpt_url,
                config_url,
                scheduler_type,
                extract_ema,
                up_token,
                up_model_name,
            ],
            outputs=output_text,
        )

    # with pr_panel:
    #     gr.Markdown("### 3. Convert and submit as a PR")
    #     gr.Markdown(
    #         "This will open a Pull Request on the original model repository, if it already exists on the Hub.\n\n"
    #         "Paste a write-access token from [https://huggingface.co/settings/tokens](https://huggingface.co/settings/tokens)"
    #         " and paste an existing model id from the Hub in the `username/model-name` form."
    #     )
    #     pr_token = gr.Textbox(
    #         max_lines=1,
    #         label="Hugging Face token",
    #     )
    #     pr_model_name = gr.Textbox(
    #         max_lines=1,
    #         label="Hub model name (e.g. `diffuser/artistic-diffusion-v1`)",
    #         placeholder="diffuser/my-awesome-model",
    #     )
    #
    #     btn = gr.Button("Convert and open a PR")
    #     output = gr.Markdown(label="Output")
# Run one conversion at a time: both convert handlers delete and reuse the
# same MODELS_DIR scratch directory, so concurrent runs would clobber each
# other. Gradio 4 replaced queue(concurrency_count=...) with
# default_concurrency_limit, and passing the old keyword raises there.
if int(gr.__version__.split(".", 1)[0]) >= 4:
    interface.queue(default_concurrency_limit=1)
else:
    interface.queue(concurrency_count=1)
interface.launch()