voxmenthe committed on
Commit
105a9fa
·
1 Parent(s): ff78fc6

apply fix to config dict passing for inference

Browse files
Files changed (1) hide show
  1. inference.py +33 -21
inference.py CHANGED
@@ -1,5 +1,5 @@
1
  import torch
2
- from transformers import AutoTokenizer, AutoModelForSequenceClassification, ModernBertConfig
3
  from typing import Dict, Any
4
  import yaml
5
  import os
@@ -84,36 +84,48 @@ class SentimentInference:
84
  # Load from Hugging Face Hub
85
  print(f"[INFERENCE_LOG] Attempting to load model from HUGGING_FACE_HUB: {model_hf_repo_id}") # Logging
86
 
87
- # Here, we use the config that's packaged with the model on the Hub by default.
88
- # We just add/override num_labels, pooling_strategy, num_weighted_layers if they are in our local config.yaml
89
- # as these might be specific to our fine-tuning and not in the Hub's default config.json.
90
- hub_config_overrides = {
91
  "num_labels": model_yaml_cfg.get('num_labels', 1),
92
  "pooling_strategy": model_yaml_cfg.get('pooling_strategy', 'mean'),
93
- "num_weighted_layers": model_yaml_cfg.get('num_weighted_layers', 6) # Default to 6 now
94
  }
95
- print(f"[INFERENCE_LOG] HUB_LOAD: Overrides for Hub config: {hub_config_overrides}") # Logging
96
 
97
  try:
98
- # Using ModernBertForSentiment.from_pretrained directly.
99
- # This assumes the config.json on the Hub for 'model_hf_repo_id' is compatible
100
- # or that from_pretrained can correctly initialize ModernBertForSentiment with it.
 
 
 
 
 
101
  self.model = ModernBertForSentiment.from_pretrained(
102
  model_hf_repo_id,
103
- **hub_config_overrides
104
  )
105
- print(f"[INFERENCE_LOG] HUB_LOAD: Model ModernBertForSentiment loaded successfully from {model_hf_repo_id}.") # Logging
106
  except Exception as e:
107
- print(f"[INFERENCE_LOG] HUB_LOAD: Error loading ModernBertForSentiment from {model_hf_repo_id}: {e}") # Logging
108
  print(f"[INFERENCE_LOG] HUB_LOAD: Falling back to AutoModelForSequenceClassification for {model_hf_repo_id}.") # Logging
109
- # Fallback: Try with AutoModelForSequenceClassification if ModernBertForSentiment fails
110
- # This might happen if the Hub model isn't strictly saved as a ModernBertForSentiment type
111
- # or if its config.json doesn't have _custom_class set, etc.
112
- self.model = AutoModelForSequenceClassification.from_pretrained(
113
- model_hf_repo_id,
114
- **hub_config_overrides
115
- )
116
- print(f"[INFERENCE_LOG] HUB_LOAD: AutoModelForSequenceClassification loaded for {model_hf_repo_id}.") # Logging
 
 
 
 
 
 
 
 
 
 
117
 
118
  self.model.eval()
119
 
 
1
  import torch
2
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoConfig, ModernBertConfig
3
  from typing import Dict, Any
4
  import yaml
5
  import os
 
84
  # Load from Hugging Face Hub
85
  print(f"[INFERENCE_LOG] Attempting to load model from HUGGING_FACE_HUB: {model_hf_repo_id}") # Logging
86
 
87
+ hub_config_params = {
 
 
 
88
  "num_labels": model_yaml_cfg.get('num_labels', 1),
89
  "pooling_strategy": model_yaml_cfg.get('pooling_strategy', 'mean'),
90
+ "num_weighted_layers": model_yaml_cfg.get('num_weighted_layers', 6)
91
  }
92
+ print(f"[INFERENCE_LOG] HUB_LOAD: Parameters to update Hub config: {hub_config_params}") # Logging
93
 
94
  try:
95
+ # Step 1: Load config from Hub, allowing for our custom ModernBertConfig
96
+ config = ModernBertConfig.from_pretrained(model_hf_repo_id)
97
+ # Step 2: Update the loaded config with our specific parameters
98
+ for key, value in hub_config_params.items():
99
+ setattr(config, key, value)
100
+ print(f"[INFERENCE_LOG] HUB_LOAD: Updated config: {config.to_diff_dict()}")
101
+
102
+ # Step 3: Load model with the updated config
103
  self.model = ModernBertForSentiment.from_pretrained(
104
  model_hf_repo_id,
105
+ config=config
106
  )
107
+ print(f"[INFERENCE_LOG] HUB_LOAD: Model ModernBertForSentiment loaded successfully from {model_hf_repo_id} with updated config.") # Logging
108
  except Exception as e:
109
+ print(f"[INFERENCE_LOG] HUB_LOAD: Error loading ModernBertForSentiment from {model_hf_repo_id} with explicit config: {e}") # Logging
110
  print(f"[INFERENCE_LOG] HUB_LOAD: Falling back to AutoModelForSequenceClassification for {model_hf_repo_id}.") # Logging
111
+
112
+ # Fallback: Try with AutoModelForSequenceClassification
113
+ # Load its config (could be BertConfig or ModernBertConfig if auto-detected)
114
+ # AutoConfig should ideally resolve to ModernBertConfig if architectures field is set in Hub's config.json
115
+ try:
116
+ config_fallback = AutoConfig.from_pretrained(model_hf_repo_id)
117
+ for key, value in hub_config_params.items():
118
+ setattr(config_fallback, key, value)
119
+ print(f"[INFERENCE_LOG] HUB_LOAD_FALLBACK: Updated fallback config: {config_fallback.to_diff_dict()}")
120
+
121
+ self.model = AutoModelForSequenceClassification.from_pretrained(
122
+ model_hf_repo_id,
123
+ config=config_fallback
124
+ )
125
+ print(f"[INFERENCE_LOG] HUB_LOAD_FALLBACK: AutoModelForSequenceClassification loaded for {model_hf_repo_id} with updated config.") # Logging
126
+ except Exception as e_fallback:
127
+ print(f"[INFERENCE_LOG] HUB_LOAD_FALLBACK: Critical error during fallback load: {e_fallback}")
128
+ raise e_fallback # Re-raise if fallback also fails catastrophically
129
 
130
  self.model.eval()
131