silveroxides committed on
Commit
b970c25
·
1 Parent(s): ee0fa2a

add app.py and conversion code

Browse files
Files changed (5) hide show
  1. app.py +34 -0
  2. convert.py +73 -0
  3. hf_utils.py +50 -0
  4. requirements.txt +8 -0
  5. utils.py +6 -0
app.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import gradio as gr

from convert import convert

# User-facing walkthrough rendered above the form.
DESCRIPTION = """
The steps are the following:

- Create a new model repo for the converted model if you do not already have one.
- Paste a read-access token in the your_hf_token box from hf.co/settings/tokens. Read access is enough given that we will open a PR against your created repo.
- Input a model id (username/repo) which can be put in clipboard by clicking the copy icon ⧉ next to the title of the repo then paste in the model_id box.
- Input the filename from the root dir of the repo that you would like to convert which can be added to clipboard by clicking the filename and then the copy icon ⧉ next to file names title and input that to filename box.
- Paste the model id of your new repo in the your_model_id box.
- Click "Submit".
- That's it! You'll get feedback if it works or not, and if it worked, you'll get the URL of the opened PR 🔥.

⚠️ If you encounter weird error messages, please have a look into the Logs and feel free to open a PR to correct the error messages.
"""

# One single-line text box per argument of convert(), in call order.
_INPUT_LABELS = ("your_hf_token", "model_id", "filename", "your_model_id")

demo = gr.Interface(
    fn=convert,
    inputs=[gr.Text(max_lines=1, label=label) for label in _INPUT_LABELS],
    outputs=[gr.Markdown(label="output")],
    title="Convert any weights only .pt, .pth, .bin, .ckpt to .safetensors and open a PR",
    description=DESCRIPTION,
    article="placeholder",
    flagging_mode="never",
)

demo.launch(show_api=True)
convert.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import requests
3
+ import os
4
+ import shutil
5
+ from pathlib import Path
6
+ from typing import Any
7
+ from tempfile import TemporaryDirectory
8
+ from typing import Optional
9
+
10
+ import torch
11
+ from io import BytesIO
12
+
13
+ from huggingface_hub import CommitInfo, Discussion, HfApi, hf_hub_download
14
+ from huggingface_hub.file_download import repo_folder_name
15
+ import spaces
16
+
17
+
18
# Commit description used for the PR opened on the target repo.
# NOTE(review): the string contains no "{}" placeholder, so the
# COMMIT_MESSAGE.format(...) call in convert() is a no-op — confirm intent.
COMMIT_MESSAGE = " This PR adds weights in safetensors format"
# Conversion is done on CPU; torch.load below maps all tensors here.
device = "cpu"
20
+
21
def convert_pt_to_safetensors(model_path, safe_path):
    """Load a torch checkpoint and re-serialize it in safetensors format.

    Args:
        model_path: path to a ``.pt``/``.pth``/``.bin``/``.ckpt`` file.
        safe_path: destination path for the ``.safetensors`` file.
    """
    # Bug fix: save_file was called below but never imported anywhere in
    # this module, so every conversion died with a NameError. Import it
    # here from safetensors (already in requirements.txt).
    from safetensors.torch import save_file

    # SECURITY: weights_only=False unpickles arbitrary objects from an
    # untrusted, user-uploaded checkpoint — this is arbitrary code
    # execution. Kept as-is because some checkpoints carry non-tensor
    # state, but prefer weights_only=True whenever the inputs allow it.
    model = torch.load(model_path, map_location="cpu", weights_only=False)
    metadata = {"format": "pt"}
    # NOTE(review): save_file expects a flat dict of tensors; wrapped
    # checkpoints (e.g. {"state_dict": {...}}) would fail here — confirm
    # the expected input layout against real uploads.
    save_file(model, safe_path, metadata)
25
+
26
def convert_single(model_id: str, filename: str, folder: str, progress: Any, token: str):
    """Fetch one checkpoint (locally or from the Hub) and convert it.

    Returns:
        Path of the written ``.safetensors`` file, placed next to the
        source checkpoint with the extension swapped.
    """
    progress(0, desc="Downloading model")

    # Prefer a file already present on disk; otherwise pull it from the Hub.
    candidate = os.path.join(model_id, filename)
    if os.path.isfile(candidate):
        source_path = candidate
    else:
        source_path = hf_hub_download(
            repo_id=model_id,
            filename=filename,
            token=token,
            local_dir=os.path.dirname(candidate),
        )

    stem, _ext = os.path.splitext(source_path)
    target_path = stem + ".safetensors"
    convert_pt_to_safetensors(source_path, target_path)
    return target_path
38
+
39
+
40
def previous_pr(api: "HfApi", model_id: str, pr_title: str) -> Optional["Discussion"]:
    """Find an already-open PR on *model_id* titled *pr_title* that targets main.

    Returns the matching Discussion, or None when the repo's discussions
    cannot be listed or no such PR exists.
    """
    try:
        all_discussions = api.get_repo_discussions(repo_id=model_id)
    except Exception:
        return None

    for disc in all_discussions:
        # Skip anything that is not an open PR with the exact title.
        if disc.status != "open" or not disc.is_pull_request or disc.title != pr_title:
            continue
        detail = api.get_discussion_details(repo_id=model_id, discussion_num=disc.num)
        if detail.target_branch == "refs/heads/main":
            return disc
    return None
50
+
51
# NOTE(review): decorated for ZeroGPU, yet the conversion itself maps
# everything to CPU (device = "cpu", map_location="cpu") — confirm the
# GPU reservation is actually needed.
@spaces.GPU()
def convert(token: str, model_id: str, filename: str, your_model_id: str, progress=gr.Progress()):
    """Convert one checkpoint from *model_id* and open a PR on *your_model_id*.

    Returns a markdown string with the URL of the created PR; any failure
    is surfaced to the UI as a gr.exceptions.Error.
    (The gr.Progress() default is the standard Gradio idiom for progress
    tracking, not an accidental mutable default.)
    """
    api = HfApi()

    pr_title = "Adding model converted to .safetensors"

    with TemporaryDirectory() as d:
        folder = os.path.join(d, repo_folder_name(repo_id=model_id, repo_type="models"))
        os.makedirs(folder)
        new_pr = None
        try:
            converted_model = convert_single(model_id, filename, folder, progress, token)
            progress(0.7, desc="Uploading to Hub")
            # NOTE(review): COMMIT_MESSAGE has no "{}", so .format(...) is a no-op.
            new_pr = api.upload_file(path_or_fileobj=converted_model, path_in_repo=filename, repo_id=your_model_id, repo_type="model", token=token, commit_message=pr_title, commit_description=COMMIT_MESSAGE.format(your_model_id), create_pr=True)
            # Assumes upload_file's return is string-like and embeds the PR
            # number after a "%2F" URL-escape — fragile; verify against the
            # huggingface_hub version in use (CommitInfo exposes .pr_url).
            pr_number = new_pr.split("%2F")[-1].split("/")[0]
            link = f"Pr created at: {'https://huggingface.co/' + os.path.join(your_model_id, 'discussions', pr_number)}"
            progress(1, desc="Done")
        except Exception as e:
            # Re-raise as a Gradio error so the message shows in the UI.
            raise gr.exceptions.Error(str(e))
        finally:
            # Redundant with TemporaryDirectory cleanup, but frees the large
            # checkpoint immediately rather than at context exit.
            shutil.rmtree(folder)

    return link
hf_utils.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from huggingface_hub import get_hf_file_metadata, hf_hub_url, hf_hub_download, scan_cache_dir, whoami, list_models
2
+
3
+
4
def get_my_model_names(token):
    """List the model ids owned by the user the token belongs to.

    Args:
        token: a Hugging Face access token.

    Returns:
        (model_ids, error): a list of ``"user/repo"`` strings and None on
        success, or an empty list and the caught exception on failure.
    """
    try:
        author = whoami(token=token)
        # `use_auth_token=` and `.modelId` are deprecated/removed in current
        # huggingface_hub releases (requirements.txt pins no version, so the
        # latest gets installed); use the supported `token=` / `.id` forms.
        model_infos = list_models(author=author["name"], token=token)
        return [model.id for model in model_infos], None

    except Exception as e:
        return [], e
13
+
14
def download_file(repo_id: str, filename: str, token: str):
    """Download *filename* from *repo_id*, pinned to the commit currently at HEAD.

    Returns:
        file_path (:obj:`str`): The path to the downloaded file.
        revision (:obj:`str`): The commit hash of the file.
    """
    # Resolve the current commit first so the download is reproducible.
    file_url = hf_hub_url(repo_id=repo_id, filename=filename)
    commit_hash = get_hf_file_metadata(file_url, token=token).commit_hash

    local_path = hf_hub_download(
        repo_id=repo_id,
        filename=filename,
        revision=commit_hash,
        token=token,
    )
    return local_path, commit_hash
28
+
29
def delete_file(revision: str):
    """Remove one revision's files from the local Hugging Face cache.

    Args:
        revision (:obj:`str`): The commit hash whose cached files to delete.
    Returns:
        None
    """
    cache_info = scan_cache_dir()
    deletion_plan = cache_info.delete_revisions(revision)
    deletion_plan.execute()
38
+
39
def get_pr_url(api, repo_id, title):
    """Return the URL of an open PR on *repo_id* whose title equals *title*.

    Returns None when discussions cannot be listed or no open PR matches.
    """
    try:
        discussions = api.get_repo_discussions(repo_id=repo_id)
    except Exception:
        return None

    for disc in discussions:
        matches = disc.status == "open" and disc.is_pull_request and disc.title == title
        if matches:
            return f"https://huggingface.co/{repo_id}/discussions/{disc.num}"
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ huggingface_hub[hf_transfer]
2
+ safetensors
3
+ transformers
4
+ accelerate
5
+ omegaconf
6
+ pytorch_lightning
7
+ pyngrok
8
+
utils.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
def is_google_colab():
    """Return True when running inside Google Colab, False otherwise.

    Detection is done by attempting to import the ``google.colab`` module,
    which only exists in the Colab runtime.
    """
    try:
        import google.colab  # noqa: F401  (import used only as a probe)
        return True
    # Bug fix: the bare `except:` also swallowed KeyboardInterrupt and
    # SystemExit; only a failed import means "not Colab".
    except ImportError:
        return False