voxmenthe committed on
Commit
b976908
·
1 Parent(s): cc8ec94

update inference to load from hf model id

Browse files
Files changed (2) hide show
  1. config.yaml +1 -1
  2. inference.py +44 -58
config.yaml CHANGED
@@ -1,6 +1,6 @@
1
  model:
2
  name: "voxmenthe/modernbert-imdb-sentiment"
3
- output_dir: "checkpoints"
4
  max_length: 880 # 256
5
  dropout: 0.1
6
  pooling_strategy: "mean" # Current default, change as needed
 
1
  model:
2
  name: "voxmenthe/modernbert-imdb-sentiment"
3
+ tokenizer_name_or_path: "answerdotai/ModernBERT-base"
4
  max_length: 880 # 256
5
  dropout: 0.1
6
  pooling_strategy: "mean" # Current default, change as needed
inference.py CHANGED
@@ -1,79 +1,65 @@
1
  import torch
2
- from transformers import AutoTokenizer, AutoModelForSequenceClassification
3
- from models import ModernBertForSentiment
4
- from transformers import ModernBertConfig
5
  from typing import Dict, Any
6
  import yaml
7
- import os
8
-
9
 
10
  class SentimentInference:
11
  def __init__(self, config_path: str = "config.yaml"):
12
- """Load configuration and initialize model and tokenizer."""
13
  with open(config_path, 'r') as f:
14
- config = yaml.safe_load(f)
15
-
16
- model_cfg = config.get('model', {})
17
- inference_cfg = config.get('inference', {})
18
 
19
- # Path to the .pt model weights file
20
- model_weights_path = inference_cfg.get('model_path',
21
- os.path.join(model_cfg.get('output_dir', 'checkpoints'), 'best_model.pt'))
22
 
23
- # Base model name from config (e.g., 'answerdotai/ModernBERT-base')
24
- # This will be used for loading both tokenizer and base BERT config from Hugging Face Hub
25
- base_model_name = model_cfg.get('name', 'answerdotai/ModernBERT-base')
26
-
27
- self.max_length = inference_cfg.get('max_length', model_cfg.get('max_length', 256))
28
 
29
- # Load tokenizer from the base model name (e.g., from Hugging Face Hub)
30
- print(f"Loading tokenizer from: {base_model_name}")
31
- self.tokenizer = AutoTokenizer.from_pretrained(base_model_name)
32
-
33
- # Load base BERT config from the base model name
34
- print(f"Loading ModernBertConfig from: {base_model_name}")
35
- bert_config = ModernBertConfig.from_pretrained(base_model_name)
36
-
37
- # --- Apply any necessary overrides from your config to the loaded bert_config ---
38
- # For example, if your ModernBertForSentiment expects specific config values beyond the base BERT model.
39
- # Your current ModernBertForSentiment takes the entire config object, which might implicitly carry these.
40
- # However, explicitly setting them on bert_config loaded from HF is safer if they are architecturally relevant.
41
- bert_config.classifier_dropout = model_cfg.get('dropout', bert_config.classifier_dropout) # Example
42
- # Ensure num_labels is set if your inference model needs it (usually for HF pipeline, less so for manual predict)
43
- # bert_config.num_labels = model_cfg.get('num_labels', 1) # Typically 1 for binary sentiment regression-style output
44
 
45
- # It's also important that pooling_strategy and num_weighted_layers are set on the config object
46
- # that ModernBertForSentiment receives, as it uses these to build its layers.
47
- # These are usually fine-tuning specific, not part of the base HF config, so they should come from your model_cfg.
48
- bert_config.pooling_strategy = model_cfg.get('pooling_strategy', 'cls')
49
- bert_config.num_weighted_layers = model_cfg.get('num_weighted_layers', 4)
50
- bert_config.loss_function = model_cfg.get('loss_function', {'name': 'SentimentWeightedLoss', 'params': {}}) # Needed by model init
51
- # Ensure num_labels is explicitly set for the model's classifier head
52
- bert_config.num_labels = 1 # For sentiment (positive/negative) often treated as 1 logit output
53
 
54
- print("Instantiating ModernBertForSentiment model structure...")
55
- self.model = ModernBertForSentiment(bert_config)
56
 
57
- print(f"Loading model weights from local checkpoint: {model_weights_path}")
58
- # Load the entire checkpoint dictionary first
59
- checkpoint = torch.load(model_weights_path, map_location=torch.device('cpu'))
 
 
60
 
61
- # Extract the model_state_dict from the checkpoint
62
- # This handles the case where the checkpoint saves more than just the model weights (e.g., optimizer state, epoch)
63
- if 'model_state_dict' in checkpoint:
64
- model_state_to_load = checkpoint['model_state_dict']
65
- else:
66
- # If the checkpoint is just the state_dict itself (older format or different saving convention)
67
- model_state_to_load = checkpoint
68
-
69
- self.model.load_state_dict(model_state_to_load)
 
 
 
 
 
 
 
 
 
 
 
 
70
  self.model.eval()
71
- print("Model loaded successfully.")
72
 
73
  def predict(self, text: str) -> Dict[str, Any]:
74
- inputs = self.tokenizer(text, return_tensors="pt", truncation=True, max_length=self.max_length)
75
  with torch.no_grad():
76
  outputs = self.model(input_ids=inputs['input_ids'], attention_mask=inputs['attention_mask'])
77
- logits = outputs["logits"]
 
 
78
  prob = torch.sigmoid(logits).item()
79
  return {"sentiment": "positive" if prob > 0.5 else "negative", "confidence": prob}
 
1
  import torch
2
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification, ModernBertConfig
3
+ # models.py (containing ModernBertForSentiment) will be loaded from the Hub due to trust_remote_code=True
 
4
  from typing import Dict, Any
5
  import yaml
 
 
6
 
7
  class SentimentInference:
8
  def __init__(self, config_path: str = "config.yaml"):
9
+ """Load configuration and initialize model and tokenizer from Hugging Face Hub."""
10
  with open(config_path, 'r') as f:
11
+ config_data = yaml.safe_load(f)
 
 
 
12
 
13
+ model_yaml_cfg = config_data.get('model', {})
14
+ inference_yaml_cfg = config_data.get('inference', {})
 
15
 
16
+ model_hf_repo_id = model_yaml_cfg.get('name_or_path')
17
+ if not model_hf_repo_id:
18
+ raise ValueError("model.name_or_path must be specified in config.yaml (e.g., 'username/model_name')")
 
 
19
 
20
+ tokenizer_hf_repo_id = model_yaml_cfg.get('tokenizer_name_or_path', model_hf_repo_id)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
+ self.max_length = inference_yaml_cfg.get('max_length', model_yaml_cfg.get('max_length', 512))
 
 
 
 
 
 
 
23
 
24
+ print(f"Loading tokenizer from: {tokenizer_hf_repo_id}")
25
+ self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_hf_repo_id)
26
 
27
+ print(f"Loading base ModernBertConfig from: {model_hf_repo_id}")
28
+ # Load the config that was uploaded with the model (config.json in the HF repo)
29
+ # This config should already have the correct architecture defined by ModernBertConfig.
30
+ # We then augment it with any custom parameters needed by ModernBertForSentiment's __init__.
31
+ loaded_config = ModernBertConfig.from_pretrained(model_hf_repo_id)
32
 
33
+ # Augment loaded_config with parameters from model_yaml_cfg needed for ModernBertForSentiment initialization
34
+ # These should reflect how the model was trained and its specific custom head.
35
+ loaded_config.pooling_strategy = model_yaml_cfg.get('pooling_strategy', 'mean') # Default to 'mean' as per your models.py change
36
+ loaded_config.num_weighted_layers = model_yaml_cfg.get('num_weighted_layers', 4)
37
+ loaded_config.classifier_dropout = model_yaml_cfg.get('dropout') # Allow None if not in yaml
38
+ # num_labels should ideally be in the config.json uploaded to HF, but can be set here if needed.
39
+ # For binary sentiment with a single logit output, num_labels is 1.
40
+ loaded_config.num_labels = model_yaml_cfg.get('num_labels', 1)
41
+ # The loss_function might not be strictly needed for inference if the model doesn't use it in forward pass for eval,
42
+ # but if ModernBertForSentiment.__init__ requires it, it must be provided.
43
+ # Assuming it's not critical for basic inference here to simplify.
44
+ # loaded_config.loss_function = model_yaml_cfg.get('loss_function', {'name': '...', 'params': {}})
45
+
46
+ print(f"Instantiating and loading model weights for {model_hf_repo_id}...")
47
+ # trust_remote_code=True allows loading models.py (containing ModernBertForSentiment)
48
+ # from the Hugging Face model repository.
49
+ self.model = AutoModelForSequenceClassification.from_pretrained(
50
+ model_hf_repo_id,
51
+ config=loaded_config, # Pass the augmented config
52
+ trust_remote_code=True
53
+ )
54
  self.model.eval()
55
+ print(f"Model {model_hf_repo_id} loaded successfully from Hugging Face Hub.")
56
 
57
  def predict(self, text: str) -> Dict[str, Any]:
58
+ inputs = self.tokenizer(text, return_tensors="pt", truncation=True, max_length=self.max_length, padding=True)
59
  with torch.no_grad():
60
  outputs = self.model(input_ids=inputs['input_ids'], attention_mask=inputs['attention_mask'])
61
+ logits = outputs.get("logits") # Use .get for safety
62
+ if logits is None:
63
+ raise ValueError("Model output did not contain 'logits'. Check model's forward pass.")
64
  prob = torch.sigmoid(logits).item()
65
  return {"sentiment": "positive" if prob > 0.5 else "negative", "confidence": prob}