jallenjia committed on
Commit
8ced43a
·
1 Parent(s): c422602

manual load weight

Browse files
.gitignore CHANGED
@@ -21,3 +21,4 @@ venv/
21
  *.log
22
  web_custom_versions/
23
  .DS_Store
 
 
21
  *.log
22
  web_custom_versions/
23
  .DS_Store
24
+ python_lib/
custom_nodes/comfyui-florence2/nodes.py CHANGED
@@ -128,29 +128,29 @@ class DownloadAndLoadFlorence2Model:
128
 
129
  print(f"Florence2 using {attention} for attention")
130
 
131
- if convert_to_safetensors:
132
- model_weight_path = os.path.join(model_path, 'pytorch_model.bin')
133
- if os.path.exists(model_weight_path):
134
- safetensors_weight_path = os.path.join(model_path, 'model.safetensors')
135
- print(f"Converting {model_weight_path} to {safetensors_weight_path}")
136
- if not os.path.exists(safetensors_weight_path):
137
- sd = torch.load(model_weight_path, map_location=offload_device)
138
- sd_new = {}
139
- for k, v in sd.items():
140
- sd_new[k] = v.clone()
141
- save_file(sd_new, safetensors_weight_path)
142
- if os.path.exists(safetensors_weight_path):
143
- print(f"Conversion successful. Deleting original file: {model_weight_path}")
144
- os.remove(model_weight_path)
145
- print(f"Original {model_weight_path} file deleted.")
146
 
147
  if transformers.__version__ < '4.51.0':
148
  with patch("transformers.dynamic_module_utils.get_imports", fixed_get_imports): #workaround for unnecessary flash_attn requirement
149
- model = AutoModelForCausalLM.from_pretrained(model_path, attn_implementation=attention, torch_dtype=dtype,trust_remote_code=True, device_map="cpu", low_cpu_mem_usage=False)
 
 
 
150
  else:
151
  from .modeling_florence2 import Florence2ForConditionalGeneration
152
- model = Florence2ForConditionalGeneration.from_pretrained(model_path, attn_implementation=attention, torch_dtype=dtype, device_map="cpu", low_cpu_mem_usage=False)
153
-
 
 
 
 
154
  processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
155
 
156
  if lora is not None:
@@ -231,28 +231,32 @@ class Florence2ModelLoader:
231
  model_path = Florence2ModelLoader.model_paths.get(model)
232
  print(f"Loading model from {model_path}")
233
  print(f"Florence2 using {attention} for attention")
234
- if convert_to_safetensors:
235
- model_weight_path = os.path.join(model_path, 'pytorch_model.bin')
236
- if os.path.exists(model_weight_path):
237
- safetensors_weight_path = os.path.join(model_path, 'model.safetensors')
238
- print(f"Converting {model_weight_path} to {safetensors_weight_path}")
239
- if not os.path.exists(safetensors_weight_path):
240
- sd = torch.load(model_weight_path, map_location=offload_device)
241
- sd_new = {}
242
- for k, v in sd.items():
243
- sd_new[k] = v.clone()
244
- save_file(sd_new, safetensors_weight_path)
245
- if os.path.exists(safetensors_weight_path):
246
- print(f"Conversion successful. Deleting original file: {model_weight_path}")
247
- os.remove(model_weight_path)
248
- print(f"Original {model_weight_path} file deleted.")
249
 
250
  if transformers.__version__ < '4.51.0':
251
  with patch("transformers.dynamic_module_utils.get_imports", fixed_get_imports): #workaround for unnecessary flash_attn requirement
252
- model = AutoModelForCausalLM.from_pretrained(model_path, attn_implementation=attention, torch_dtype=dtype,trust_remote_code=True, device_map="cpu", low_cpu_mem_usage=False)
 
 
 
253
  else:
254
  from .modeling_florence2 import Florence2ForConditionalGeneration
255
- model = Florence2ForConditionalGeneration.from_pretrained(model_path, attn_implementation=attention, torch_dtype=dtype, device_map="cpu", low_cpu_mem_usage=False)
 
 
 
256
  processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
257
 
258
  if lora is not None:
 
128
 
129
  print(f"Florence2 using {attention} for attention")
130
 
131
+ from transformers import AutoConfig
132
+
133
+ # Manually load the state dict to CPU to avoid issues with ZeroGPU patching
134
+ print("Manually loading weights to CPU...")
135
+ weights_path = os.path.join(model_path, "pytorch_model.bin")
136
+ state_dict = torch.load(weights_path, map_location="cpu")
137
+
138
+ config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
 
 
 
 
 
 
 
139
 
140
  if transformers.__version__ < '4.51.0':
141
  with patch("transformers.dynamic_module_utils.get_imports", fixed_get_imports): #workaround for unnecessary flash_attn requirement
142
+ model = AutoModelForCausalLM.from_pretrained(
143
+ None, config=config, state_dict=state_dict, attn_implementation=attention,
144
+ torch_dtype=dtype, trust_remote_code=True
145
+ )
146
  else:
147
  from .modeling_florence2 import Florence2ForConditionalGeneration
148
+ model = Florence2ForConditionalGeneration.from_pretrained(
149
+ None, config=config, state_dict=state_dict, attn_implementation=attention, torch_dtype=dtype
150
+ )
151
+
152
+ # We don't need to call .to(offload_device) here as it's already on CPU
153
+ # and the run node will handle moving it to the GPU.
154
  processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
155
 
156
  if lora is not None:
 
231
  model_path = Florence2ModelLoader.model_paths.get(model)
232
  print(f"Loading model from {model_path}")
233
  print(f"Florence2 using {attention} for attention")
234
+
235
+ from transformers import AutoConfig
236
+
237
+ # Manually load the state dict to CPU to avoid issues with ZeroGPU patching
238
+ print("Manually loading weights to CPU...")
239
+ # Prefer safetensors if they exist (potentially after conversion)
240
+ weights_path = os.path.join(model_path, "model.safetensors")
241
+ if not os.path.exists(weights_path):
242
+ weights_path = os.path.join(model_path, "pytorch_model.bin")
243
+
244
+ state_dict = torch.load(weights_path, map_location="cpu")
245
+
246
+ config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
 
 
247
 
248
  if transformers.__version__ < '4.51.0':
249
  with patch("transformers.dynamic_module_utils.get_imports", fixed_get_imports): #workaround for unnecessary flash_attn requirement
250
+ model = AutoModelForCausalLM.from_pretrained(
251
+ None, config=config, state_dict=state_dict, attn_implementation=attention,
252
+ torch_dtype=dtype, trust_remote_code=True
253
+ )
254
  else:
255
  from .modeling_florence2 import Florence2ForConditionalGeneration
256
+ model = Florence2ForConditionalGeneration.from_pretrained(
257
+ None, config=config, state_dict=state_dict, attn_implementation=attention, torch_dtype=dtype
258
+ )
259
+
260
  processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
261
 
262
  if lora is not None: