# silveroxides's picture
# Update convert.py
# 111380e verified
# raw
# history blame
# 4.17 kB
import gradio as gr
import requests
import os
import shutil
from pathlib import Path
from typing import Any
from tempfile import TemporaryDirectory
from typing import Optional
import torch
from io import BytesIO
from safetensors.torch import save_file
from huggingface_hub import CommitInfo, Discussion, HfApi, hf_hub_download
from huggingface_hub.file_download import repo_folder_name
import spaces
# Commit description attached to the PR opened by `convert`.
# NOTE(review): `convert` calls `.format(your_model_id)` on this string, but it
# contains no "{}" placeholder, so the formatting is a no-op — confirm intent.
COMMIT_MESSAGE = " This PR adds weights in safetensors format"
# NOTE(review): `device` is never referenced in this file; all loading is
# pinned to CPU via map_location="cpu" instead.
device = "cpu"
def convert_pt_to_safetensors(model_path, safe_path):
    """Convert a pickled PyTorch checkpoint to the safetensors format.

    Parameters
    ----------
    model_path : path to the input ``.pt`` / ``.pth`` / ``.ckpt`` file.
    safe_path : path where the ``.safetensors`` file is written.
    """
    # Prefer weights_only=True: it refuses to execute arbitrary code embedded
    # in the pickle, which is the safe default for untrusted checkpoints.
    try:
        model = torch.load(model_path, map_location="cpu", weights_only=True)
    except Exception:
        # Fallback for legacy checkpoints that need full unpickling.
        # NOTE(security): this path executes code stored in the pickle.
        model = torch.load(model_path, map_location="cpu")
    # A full nn.Module object was saved: extract its parameters.
    if not isinstance(model, dict):
        model = model.state_dict()
    # Many trainers wrap the weights under a "state_dict" key
    # (e.g. {"epoch": ..., "state_dict": {...}}); unwrap it.
    if isinstance(model.get("state_dict"), dict):
        model = model["state_dict"]
    # safetensors requires contiguous tensors and cannot serialize non-tensor
    # entries (optimizer state, step counters, ...), so normalize/filter here
    # instead of letting save_file raise.
    tensors = {
        key: value.contiguous()
        for key, value in model.items()
        if isinstance(value, torch.Tensor)
    }
    metadata = {"format": "pt"}
    save_file(tensors, safe_path, metadata)
def convert_single(model_id: str, filename: str, folder: str, progress: Any, token: str):
    """Fetch one checkpoint file and convert it to safetensors.

    Parameters
    ----------
    model_id : Hub repo id (or a local directory path containing *filename*).
    filename : checkpoint file name inside the repo/directory.
    folder : scratch directory for downloads (removed by the caller).
    progress : gradio progress callback.
    token : Hub access token used for the download.

    Returns the path of the written ``.safetensors`` file.
    """
    progress(0, desc=f"Downloading {filename}")
    # Reuse a local copy when `model_id` is actually a directory on disk;
    # otherwise download into the scratch `folder` so the file is cleaned up
    # with the TemporaryDirectory instead of littering the working directory.
    local_file = os.path.join(model_id, filename)
    if os.path.isfile(local_file):
        model_path = local_file
    else:
        model_path = hf_hub_download(
            repo_id=model_id, filename=filename, token=token, local_dir=folder
        )
    # Same base name, new extension: "model.pth" -> "model.safetensors".
    safe_path = os.path.splitext(model_path)[0] + ".safetensors"
    convert_pt_to_safetensors(model_path, safe_path)
    return safe_path
def previous_pr(api: "HfApi", model_id: str, pr_title: str) -> Optional["Discussion"]:
    """Return the open pull request on *model_id* targeting ``main`` whose
    title equals *pr_title*, or ``None`` if there is none (or listing fails)."""
    try:
        all_discussions = api.get_repo_discussions(repo_id=model_id)
    except Exception:
        return None
    # Only open pull requests with the exact title are candidates.
    candidates = (
        disc for disc in all_discussions
        if disc.status == "open" and disc.is_pull_request and disc.title == pr_title
    )
    for candidate in candidates:
        details = api.get_discussion_details(
            repo_id=model_id, discussion_num=candidate.num
        )
        if details.target_branch == "refs/heads/main":
            return candidate
    return None
def convert(token: str, model_id: str, filename: str, your_model_id: str, progress=gr.Progress()):
    """Convert *filename* from *model_id* to safetensors and open a PR that
    adds it to *your_model_id*.

    Parameters
    ----------
    token : Hub write token used for download and upload.
    model_id : source repo containing the pickled checkpoint.
    filename : checkpoint file name inside the source repo.
    your_model_id : destination repo that receives the PR.
    progress : gradio progress callback.

    Returns a human-readable string with the created PR's URL.
    Raises ``gr.exceptions.Error`` wrapping any failure.
    """
    api = HfApi()
    pr_title = "Adding model converted to .safetensors"
    with TemporaryDirectory() as d:
        folder = os.path.join(d, repo_folder_name(repo_id=model_id, repo_type="models"))
        os.makedirs(folder)
        new_pr = None
        try:
            converted_model = convert_single(model_id, filename, folder, progress, token)
            # Keep the original file name, only swapping the extension
            # (os.path.splitext("model.pth") -> ("model", ".pth")).
            safetensors_filename = os.path.splitext(filename)[0] + ".safetensors"
            progress(0.7, desc=f"Uploading {safetensors_filename} to Hub")
            new_pr = api.upload_file(
                path_or_fileobj=converted_model,
                path_in_repo=safetensors_filename,
                repo_id=your_model_id,
                repo_type="model",
                token=token,
                commit_message=pr_title,
                commit_description=COMMIT_MESSAGE.format(your_model_id),
                create_pr=True,
            )
            # upload_file returns a (str-like) commit URL; the PR number is the
            # path segment after the URL-encoded branch name ("...%2F<num>/...").
            pr_number = str(new_pr).split("%2F")[-1].split("/")[0]
            # Build the URL with an f-string, NOT os.path.join: the latter
            # would produce backslash separators on Windows.
            link = f"Pr created at: https://huggingface.co/{your_model_id}/discussions/{pr_number}"
            progress(1, desc="Done")
        except Exception as e:
            raise gr.exceptions.Error(str(e))
        finally:
            # TemporaryDirectory removes `d` on exit; delete the subfolder
            # eagerly to free disk sooner. Guard in case creation failed.
            if os.path.exists(folder):
                shutil.rmtree(folder)
    return link