Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
apply fix to config dict passing for inference
Browse files- inference.py +33 -21
inference.py
CHANGED
@@ -1,5 +1,5 @@
|
|
1 |
import torch
|
2 |
-
from transformers import AutoTokenizer, AutoModelForSequenceClassification, ModernBertConfig
|
3 |
from typing import Dict, Any
|
4 |
import yaml
|
5 |
import os
|
@@ -84,36 +84,48 @@ class SentimentInference:
|
|
84 |
# Load from Hugging Face Hub
|
85 |
print(f"[INFERENCE_LOG] Attempting to load model from HUGGING_FACE_HUB: {model_hf_repo_id}") # Logging
|
86 |
|
87 |
-
|
88 |
-
# We just add/override num_labels, pooling_strategy, num_weighted_layers if they are in our local config.yaml
|
89 |
-
# as these might be specific to our fine-tuning and not in the Hub's default config.json.
|
90 |
-
hub_config_overrides = {
|
91 |
"num_labels": model_yaml_cfg.get('num_labels', 1),
|
92 |
"pooling_strategy": model_yaml_cfg.get('pooling_strategy', 'mean'),
|
93 |
-
"num_weighted_layers": model_yaml_cfg.get('num_weighted_layers', 6)
|
94 |
}
|
95 |
-
print(f"[INFERENCE_LOG] HUB_LOAD:
|
96 |
|
97 |
try:
|
98 |
-
#
|
99 |
-
|
100 |
-
#
|
|
|
|
|
|
|
|
|
|
|
101 |
self.model = ModernBertForSentiment.from_pretrained(
|
102 |
model_hf_repo_id,
|
103 |
-
|
104 |
)
|
105 |
-
print(f"[INFERENCE_LOG] HUB_LOAD: Model ModernBertForSentiment loaded successfully from {model_hf_repo_id}.") # Logging
|
106 |
except Exception as e:
|
107 |
-
print(f"[INFERENCE_LOG] HUB_LOAD: Error loading ModernBertForSentiment from {model_hf_repo_id}: {e}") # Logging
|
108 |
print(f"[INFERENCE_LOG] HUB_LOAD: Falling back to AutoModelForSequenceClassification for {model_hf_repo_id}.") # Logging
|
109 |
-
|
110 |
-
#
|
111 |
-
#
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
117 |
|
118 |
self.model.eval()
|
119 |
|
|
|
1 |
import torch
|
2 |
+
from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoConfig, ModernBertConfig
|
3 |
from typing import Dict, Any
|
4 |
import yaml
|
5 |
import os
|
|
|
84 |
# Load from Hugging Face Hub
|
85 |
print(f"[INFERENCE_LOG] Attempting to load model from HUGGING_FACE_HUB: {model_hf_repo_id}") # Logging
|
86 |
|
87 |
+
hub_config_params = {
|
|
|
|
|
|
|
88 |
"num_labels": model_yaml_cfg.get('num_labels', 1),
|
89 |
"pooling_strategy": model_yaml_cfg.get('pooling_strategy', 'mean'),
|
90 |
+
"num_weighted_layers": model_yaml_cfg.get('num_weighted_layers', 6)
|
91 |
}
|
92 |
+
print(f"[INFERENCE_LOG] HUB_LOAD: Parameters to update Hub config: {hub_config_params}") # Logging
|
93 |
|
94 |
try:
|
95 |
+
# Step 1: Load config from Hub, allowing for our custom ModernBertConfig
|
96 |
+
config = ModernBertConfig.from_pretrained(model_hf_repo_id)
|
97 |
+
# Step 2: Update the loaded config with our specific parameters
|
98 |
+
for key, value in hub_config_params.items():
|
99 |
+
setattr(config, key, value)
|
100 |
+
print(f"[INFERENCE_LOG] HUB_LOAD: Updated config: {config.to_diff_dict()}")
|
101 |
+
|
102 |
+
# Step 3: Load model with the updated config
|
103 |
self.model = ModernBertForSentiment.from_pretrained(
|
104 |
model_hf_repo_id,
|
105 |
+
config=config
|
106 |
)
|
107 |
+
print(f"[INFERENCE_LOG] HUB_LOAD: Model ModernBertForSentiment loaded successfully from {model_hf_repo_id} with updated config.") # Logging
|
108 |
except Exception as e:
|
109 |
+
print(f"[INFERENCE_LOG] HUB_LOAD: Error loading ModernBertForSentiment from {model_hf_repo_id} with explicit config: {e}") # Logging
|
110 |
print(f"[INFERENCE_LOG] HUB_LOAD: Falling back to AutoModelForSequenceClassification for {model_hf_repo_id}.") # Logging
|
111 |
+
|
112 |
+
# Fallback: Try with AutoModelForSequenceClassification
|
113 |
+
# Load its config (could be BertConfig or ModernBertConfig if auto-detected)
|
114 |
+
# AutoConfig should ideally resolve to ModernBertConfig if architectures field is set in Hub's config.json
|
115 |
+
try:
|
116 |
+
config_fallback = AutoConfig.from_pretrained(model_hf_repo_id)
|
117 |
+
for key, value in hub_config_params.items():
|
118 |
+
setattr(config_fallback, key, value)
|
119 |
+
print(f"[INFERENCE_LOG] HUB_LOAD_FALLBACK: Updated fallback config: {config_fallback.to_diff_dict()}")
|
120 |
+
|
121 |
+
self.model = AutoModelForSequenceClassification.from_pretrained(
|
122 |
+
model_hf_repo_id,
|
123 |
+
config=config_fallback
|
124 |
+
)
|
125 |
+
print(f"[INFERENCE_LOG] HUB_LOAD_FALLBACK: AutoModelForSequenceClassification loaded for {model_hf_repo_id} with updated config.") # Logging
|
126 |
+
except Exception as e_fallback:
|
127 |
+
print(f"[INFERENCE_LOG] HUB_LOAD_FALLBACK: Critical error during fallback load: {e_fallback}")
|
128 |
+
raise e_fallback # Re-raise if fallback also fails catastrophically
|
129 |
|
130 |
self.model.eval()
|
131 |
|