silveroxides commited on
Commit
d5bb79f
·
verified ·
1 Parent(s): 111380e

Update convert.py

Browse files
Files changed (1) hide show
  1. convert.py +24 -33
convert.py CHANGED
@@ -19,24 +19,12 @@ COMMIT_MESSAGE = " This PR adds weights in safetensors format"
19
  device = "cpu"
20
 
21
  def convert_pt_to_safetensors(model_path, safe_path):
22
- # Use weights_only=True for security unless you need to execute code from the pickle
23
- # For model conversion, it's often safer. If it fails, you can revert to False.
24
- try:
25
- model = torch.load(model_path, map_location="cpu", weights_only=True)
26
- except Exception:
27
- # Fallback for models that can't be loaded with weights_only
28
- model = torch.load(model_path, map_location="cpu")
29
-
30
- # If the loaded object is a state_dict, use it directly.
31
- # If it's a full model object, get its state_dict.
32
- if not isinstance(model, dict):
33
- model = model.state_dict()
34
-
35
  metadata = {"format":"pt"}
36
  save_file(model, safe_path, metadata)
37
 
38
  def convert_single(model_id: str, filename: str, folder: str, progress: Any, token: str):
39
- progress(0, desc=f"Downloading {filename}")
40
 
41
  local_file = os.path.join(model_id, filename)
42
  local_dir = os.path.dirname(local_file)
@@ -71,34 +59,37 @@ def convert(token: str, model_id: str, filename: str, your_model_id: str, progre
71
  new_pr = None
72
  try:
73
  converted_model = convert_single(model_id, filename, folder, progress, token)
74
-
75
- # Create the correct filename for the repository (e.g., "model.safetensors")
76
- # os.path.splitext('path/to/model.pth') -> ('path/to/model', '.pth')
77
- # We take the first part and add the new extension.
78
- safetensors_filename = os.path.splitext(filename)[0] + ".safetensors"
79
 
80
- progress(0.7, desc=f"Uploading {safetensors_filename} to Hub")
 
 
 
81
 
82
- # Use the new, correct filename in `path_in_repo`
 
 
83
  new_pr = api.upload_file(
84
- path_or_fileobj=converted_model,
85
- path_in_repo=safetensors_filename, # <-- CORRECTED
86
- repo_id=your_model_id,
87
- repo_type="model",
88
- token=token,
89
- commit_message=pr_title,
90
- commit_description=COMMIT_MESSAGE.format(your_model_id),
91
  create_pr=True
92
  )
93
-
94
  pr_number = new_pr.split("%2F")[-1].split("/")[0]
95
  link = f"Pr created at: {'https://huggingface.co/' + os.path.join(your_model_id, 'discussions', pr_number)}"
96
  progress(1, desc="Done")
97
  except Exception as e:
98
  raise gr.exceptions.Error(str(e))
99
  finally:
100
- # This check prevents an error if the folder was already removed or failed to create
101
- if os.path.exists(folder):
102
- shutil.rmtree(folder)
 
103
 
104
- return link
 
 
 
19
  device = "cpu"
20
 
21
  def convert_pt_to_safetensors(model_path, safe_path):
22
+ model = torch.load(model_path, map_location="cpu", weights_only=False)
 
 
 
 
 
 
 
 
 
 
 
 
23
  metadata = {"format":"pt"}
24
  save_file(model, safe_path, metadata)
25
 
26
  def convert_single(model_id: str, filename: str, folder: str, progress: Any, token: str):
27
+ progress(0, desc="Downloading model")
28
 
29
  local_file = os.path.join(model_id, filename)
30
  local_dir = os.path.dirname(local_file)
 
59
  new_pr = None
60
  try:
61
  converted_model = convert_single(model_id, filename, folder, progress, token)
 
 
 
 
 
62
 
63
+ # --- START OF THE REQUESTED CHANGE ---
64
+ # Create the correct destination filename by replacing the extension
65
+ safetensors_filename = os.path.splitext(filename)[0] + ".safetensors"
66
+ # --- END OF THE REQUESTED CHANGE ---
67
 
68
+ progress(0.7, desc="Uploading to Hub")
69
+
70
+ # Use the corrected filename in path_in_repo
71
  new_pr = api.upload_file(
72
+ path_or_fileobj=converted_model,
73
+ path_in_repo=safetensors_filename, # <-- THE ONLY MODIFIED ARGUMENT
74
+ repo_id=your_model_id,
75
+ repo_type="model",
76
+ token=token,
77
+ commit_message=pr_title,
78
+ commit_description=COMMIT_MESSAGE.format(your_model_id),
79
  create_pr=True
80
  )
81
+
82
  pr_number = new_pr.split("%2F")[-1].split("/")[0]
83
  link = f"Pr created at: {'https://huggingface.co/' + os.path.join(your_model_id, 'discussions', pr_number)}"
84
  progress(1, desc="Done")
85
  except Exception as e:
86
  raise gr.exceptions.Error(str(e))
87
  finally:
88
+ shutil.rmtree(folder)
89
+
90
+ finish_message(link)
91
+ return link
92
 
93
+ @spaces.GPU()
94
+ def finish_message(link: str):
95
+ return print(link)