silveroxides committed
Commit 6762348 · verified · 1 Parent(s): c39bef8

Update convert.py

Files changed (1)
  1. convert.py +37 -5
convert.py CHANGED

@@ -19,12 +19,24 @@ COMMIT_MESSAGE = " This PR adds weights in safetensors format"
 device = "cpu"
 
 def convert_pt_to_safetensors(model_path, safe_path):
-    model = torch.load(model_path, map_location="cpu", weights_only=False)
+    # Use weights_only=True for security unless you need to execute code from the pickle
+    # For model conversion, it's often safer. If it fails, you can revert to False.
+    try:
+        model = torch.load(model_path, map_location="cpu", weights_only=True)
+    except Exception:
+        # Fallback for models that can't be loaded with weights_only
+        model = torch.load(model_path, map_location="cpu")
+
+    # If the loaded object is a state_dict, use it directly.
+    # If it's a full model object, get its state_dict.
+    if not isinstance(model, dict):
+        model = model.state_dict()
+
     metadata = {"format":"pt"}
     save_file(model, safe_path, metadata)
 
 def convert_single(model_id: str, filename: str, folder: str, progress: Any, token: str):
-    progress(0, desc="Downloading model")
+    progress(0, desc=f"Downloading {filename}")
 
     local_file = os.path.join(model_id, filename)
     local_dir = os.path.dirname(local_file)
@@ -60,14 +72,34 @@ def convert(token: str, model_id: str, filename: str, your_model_id: str, progre
     new_pr = None
     try:
         converted_model = convert_single(model_id, filename, folder, progress, token)
-        progress(0.7, desc="Uploading to Hub")
-        new_pr = api.upload_file(path_or_fileobj=converted_model, path_in_repo=filename, repo_id=your_model_id, repo_type="model", token=token, commit_message=pr_title, commit_description=COMMIT_MESSAGE.format(your_model_id), create_pr=True)
+
+        # Create the correct filename for the repository (e.g., "model.safetensors")
+        # os.path.splitext('path/to/model.pth') -> ('path/to/model', '.pth')
+        # We take the first part and add the new extension.
+        safetensors_filename = os.path.splitext(filename)[0] + ".safetensors"
+
+        progress(0.7, desc=f"Uploading {safetensors_filename} to Hub")
+
+        # Use the new, correct filename in `path_in_repo`
+        new_pr = api.upload_file(
+            path_or_fileobj=converted_model,
+            path_in_repo=safetensors_filename,  # <-- CORRECTED
+            repo_id=your_model_id,
+            repo_type="model",
+            token=token,
+            commit_message=pr_title,
+            commit_description=COMMIT_MESSAGE.format(your_model_id),
+            create_pr=True
+        )
+
         pr_number = new_pr.split("%2F")[-1].split("/")[0]
         link = f"Pr created at: {'https://huggingface.co/' + os.path.join(your_model_id, 'discussions', pr_number)}"
         progress(1, desc="Done")
     except Exception as e:
         raise gr.exceptions.Error(str(e))
     finally:
-        shutil.rmtree(folder)
+        # This check prevents an error if the folder was already removed or failed to create
+        if os.path.exists(folder):
+            shutil.rmtree(folder)
 
     return link
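
For reference, a minimal standalone sketch of the loading path this commit introduces: try torch.load with weights_only=True, fall back to a plain torch.load if that fails, pull out a state dict when a full module was pickled, and write the result with safetensors.torch.save_file. The file names and the throwaway torch.nn.Linear checkpoint below are hypothetical stand-ins, not part of the repository.

import torch
from safetensors.torch import save_file, load_file

def pt_to_safetensors(model_path: str, safe_path: str) -> None:
    # Prefer the restricted unpickler; fall back for checkpoints that need full pickle support.
    try:
        obj = torch.load(model_path, map_location="cpu", weights_only=True)
    except Exception:
        obj = torch.load(model_path, map_location="cpu")

    # A pickled nn.Module needs its state_dict; a plain dict of tensors is used as-is.
    state_dict = obj if isinstance(obj, dict) else obj.state_dict()

    save_file(state_dict, safe_path, metadata={"format": "pt"})

if __name__ == "__main__":
    # Hypothetical round trip with a tiny checkpoint to confirm the keys survive conversion.
    torch.save(torch.nn.Linear(4, 2).state_dict(), "tiny.pt")
    pt_to_safetensors("tiny.pt", "tiny.safetensors")
    print(sorted(load_file("tiny.safetensors").keys()))  # ['bias', 'weight']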
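
The second hunk also stops reusing the original .pt/.bin name for the uploaded file and derives a .safetensors name instead. A quick illustration of the os.path.splitext derivation it relies on; the example filenames are hypothetical:

import os

# The extension is swapped; any directory part of the path is preserved.
for filename in ["model.pth", "unet/diffusion_pytorch_model.bin"]:
    safetensors_filename = os.path.splitext(filename)[0] + ".safetensors"
    print(f"{filename} -> {safetensors_filename}")

# model.pth -> model.safetensors
# unet/diffusion_pytorch_model.bin -> unet/diffusion_pytorch_model.safetensors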