File size: 3,526 Bytes
b970c25
 
 
 
 
 
 
 
 
 
 
c39bef8
b970c25
 
 
 
 
 
 
 
 
d5bb79f
b970c25
 
 
 
d5bb79f
b970c25
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6762348
d5bb79f
 
 
 
6762348
d5bb79f
 
 
6762348
d5bb79f
 
 
 
 
 
 
6762348
 
d5bb79f
b970c25
 
 
 
 
 
d5bb79f
 
 
 
b970c25
d5bb79f
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import gradio as gr
import requests
import os
import shutil
from pathlib import Path
from typing import Any
from tempfile import TemporaryDirectory
from typing import Optional

import torch
from io import BytesIO
from safetensors.torch import save_file
from huggingface_hub import CommitInfo, Discussion, HfApi, hf_hub_download
from huggingface_hub.file_download import repo_folder_name
import spaces


# Description attached to the conversion PR. It contains no "{}" placeholder,
# so the `.format(...)` applied to it at upload time leaves it unchanged.
COMMIT_MESSAGE = "This PR adds weights in safetensors format"
# Torch device name; conversion below loads with map_location="cpu" directly.
device = "cpu"

def convert_pt_to_safetensors(model_path, safe_path, weights_only: bool = False):
    """Convert a PyTorch checkpoint file into a ``.safetensors`` file.

    Args:
        model_path: Path to a checkpoint readable by ``torch.load``.
        safe_path: Destination path of the ``.safetensors`` file to write.
        weights_only: Forwarded to ``torch.load``. Defaults to ``False`` to keep
            the historical behaviour; pass ``True`` to refuse arbitrary pickled
            Python objects.

    Raises:
        ValueError: if the checkpoint does not resolve to a ``dict`` whose values
            are all tensors.
    """
    # SECURITY NOTE(review): with weights_only=False torch.load unpickles the
    # file, which can execute arbitrary code. The file comes from a user-chosen
    # Hub repo, so consider calling this with weights_only=True.
    checkpoint = torch.load(model_path, map_location="cpu", weights_only=weights_only)

    # A pickled full model: keep only its parameters/buffers.
    if isinstance(checkpoint, torch.nn.Module):
        checkpoint = checkpoint.state_dict()
    # Training checkpoints commonly nest the weights under "state_dict".
    if isinstance(checkpoint, dict) and isinstance(checkpoint.get("state_dict"), dict):
        checkpoint = checkpoint["state_dict"]

    if not isinstance(checkpoint, dict):
        raise ValueError(
            f"Expected {model_path} to contain a dict of tensors, "
            f"got {type(checkpoint).__name__}"
        )
    non_tensors = [name for name, value in checkpoint.items() if not isinstance(value, torch.Tensor)]
    if non_tensors:
        raise ValueError(
            f"Checkpoint {model_path} contains non-tensor entries: {non_tensors[:10]}"
        )

    # save_file rejects non-contiguous tensors; .contiguous() is a no-op for
    # tensors that already are. NOTE(review): tensors sharing storage are still
    # rejected by save_file and are not deduplicated here.
    tensors = {name: tensor.contiguous() for name, tensor in checkpoint.items()}
    save_file(tensors, safe_path, metadata={"format": "pt"})

def convert_single(model_id: str, filename: str, folder: str, progress: Any, token: str):
    """Obtain one PyTorch checkpoint and convert it to ``.safetensors``.

    If ``os.path.join(model_id, filename)`` is an existing local file it is used
    as-is; otherwise ``filename`` is downloaded from the Hub repo ``model_id``
    into ``folder``. The converted file is always written inside ``folder``, so
    the caller's cleanup of ``folder`` removes both the download and the output.

    Args:
        model_id: Hub repo id (or a local directory) holding the checkpoint.
        filename: Path of the checkpoint inside that repo/directory.
        folder: Scratch directory for downloads and the converted file.
        progress: Gradio-style progress callback ``progress(fraction, desc=...)``.
        token: Hub access token used for the download.

    Returns:
        Path of the written ``.safetensors`` file.
    """
    progress(0, desc="Downloading model")
    os.makedirs(folder, exist_ok=True)

    local_file = os.path.join(model_id, filename)
    if os.path.isfile(local_file):
        model_path = local_file
    else:
        # Download into the scratch folder instead of a CWD-relative directory
        # named after the repo, which was never cleaned up and ignored `folder`.
        model_path = hf_hub_download(
            repo_id=model_id, filename=filename, token=token, local_dir=folder
        )

    progress(0.4, desc="Converting to safetensors")
    stem = os.path.splitext(os.path.basename(filename))[0]
    safe_path = os.path.join(folder, stem + ".safetensors")
    convert_pt_to_safetensors(model_path, safe_path)

    return safe_path


def previous_pr(api: "HfApi", model_id: str, pr_title: str) -> Optional["Discussion"]:
    """Return an open PR on ``model_id`` titled ``pr_title`` that targets ``main``.

    This is a best-effort lookup: if the repo's discussions cannot be listed,
    ``None`` is returned instead of raising.

    Args:
        api: Client exposing ``get_repo_discussions`` and ``get_discussion_details``.
        model_id: Repo whose discussions are searched.
        pr_title: Exact title the pull request must have.

    Returns:
        The first matching discussion, or ``None`` if there is none.
    """
    try:
        # huggingface_hub's get_repo_discussions returns a lazy, paginated
        # iterator: HTTP errors only surface while iterating. Materialise it
        # inside the try so listing failures are actually caught.
        discussions = list(api.get_repo_discussions(repo_id=model_id))
    except Exception:
        return None
    for discussion in discussions:
        if discussion.status == "open" and discussion.is_pull_request and discussion.title == pr_title:
            # The target branch is only available from the detailed view.
            details = api.get_discussion_details(repo_id=model_id, discussion_num=discussion.num)
            if details.target_branch == "refs/heads/main":
                return discussion
    return None

def convert(token: str, model_id: str, filename: str, your_model_id: str, progress=gr.Progress()):
    """Convert ``filename`` from ``model_id`` to safetensors and open a PR on ``your_model_id``.

    Args:
        token: Hub access token used for both download and upload.
        model_id: Source repo id containing the PyTorch checkpoint.
        filename: Path of the checkpoint inside ``model_id``.
        your_model_id: Repo that receives the pull request.
        progress: Gradio progress tracker.

    Returns:
        A message of the form ``"Pr created at: <url>"``.

    Raises:
        gr.exceptions.Error: wrapping any failure during conversion or upload.
    """
    api = HfApi()

    pr_title = "Adding model converted to .safetensors"

    with TemporaryDirectory() as d:
        folder = os.path.join(d, repo_folder_name(repo_id=model_id, repo_type="models"))
        os.makedirs(folder)
        try:
            converted_model = convert_single(model_id, filename, folder, progress, token)

            # Keep the original repo path (including subfolders) and only swap
            # the extension, e.g. "sub/model.bin" -> "sub/model.safetensors".
            safetensors_filename = os.path.splitext(filename)[0] + ".safetensors"

            progress(0.7, desc="Uploading to Hub")
            new_pr = api.upload_file(
                path_or_fileobj=converted_model,
                path_in_repo=safetensors_filename,
                repo_id=your_model_id,
                repo_type="model",
                token=token,
                commit_message=pr_title,
                commit_description=COMMIT_MESSAGE.format(your_model_id),
                create_pr=True,
            )

            pr_url = getattr(new_pr, "pr_url", None)
            if not pr_url:
                # Fallback for clients returning a plain URL string whose
                # revision segment is the URL-encoded PR ref ("refs%2Fpr%2F<n>").
                pr_number = str(new_pr).split("%2F")[-1].split("/")[0]
                pr_url = f"https://huggingface.co/{your_model_id}/discussions/{pr_number}"
            link = f"Pr created at: {pr_url}"
            progress(1, desc="Done")
        except Exception as e:
            raise gr.exceptions.Error(str(e)) from e
        finally:
            # TemporaryDirectory also cleans up; ignore errors here so a cleanup
            # failure cannot mask the original exception.
            shutil.rmtree(folder, ignore_errors=True)

        finish_message(link)
        return link

@spaces.GPU()
def finish_message(link: str):
    """Write ``link`` to standard output; the function yields ``None``.

    NOTE(review): this does no GPU work; the ``spaces.GPU`` decorator is
    presumably here so a ZeroGPU Space has at least one GPU-decorated
    function — confirm before removing it.
    """
    print(link)
    return None