velocity-ai committed
Commit c410e3a · verified · 1 Parent(s): 4f1bef3

Update code/inference.py

Files changed (1)
  1. code/inference.py +12 -52
code/inference.py CHANGED
@@ -1,7 +1,7 @@
 import os
 import json
 import torch
-from transformers import AutoModelForSequenceClassification, AutoTokenizer, AutoConfig
+from transformers import AutoTokenizer, AutoModelForSequenceClassification
 import logging
 
 logger = logging.getLogger(__name__)
@@ -11,7 +11,7 @@ logger = logging.getLogger(__name__)
 # Can specify GPU device with:
 # CUDA_VISIBLE_DEVICES="1" python script.py
 
-def model_fn(model_dir):
+def model_fn(model_dir, context=None):
     """Load the model for inference"""
     try:
         model_id = os.getenv("HF_MODEL_ID")
@@ -19,22 +19,14 @@ def model_fn(model_dir):
         # Set specific GPU device if available
         device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
         if device.type == 'cuda':
-            torch.cuda.set_device(device)
             torch.cuda.empty_cache()
         logger.info(f"Using device: {device}")
 
-        # Load tokenizer
+        # Load tokenizer and model directly using AutoModelForSequenceClassification
         tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
-
-        # Load config
-        config = AutoConfig.from_pretrained(model_id,
-                                            num_labels=2,
-                                            trust_remote_code=True)
-
-        # Load model with sequence classification head
         model = AutoModelForSequenceClassification.from_pretrained(
             model_id,
-            config=config,
+            num_labels=2,
             torch_dtype=torch.bfloat16 if device.type == 'cuda' else torch.float32,
             trust_remote_code=True
         )
@@ -68,27 +60,15 @@ def predict_fn(data, model_dict):
         tokenizer = model_dict["tokenizer"]
         device = model_dict["device"]
 
-        logger.info(f"Model is on device: {device}")
-
-        # Parse input and format it like training data
+        # Parse input
         if isinstance(data, str):
             input_text = data
         elif isinstance(data, dict):
-            # Extract address components
-            addr1 = data.get('order_address1', data.get('address_line_1', ''))
-            addr2 = data.get('order_address2', data.get('address_line_2', ''))
-            city = data.get('order_city', data.get('city', ''))
-            state = data.get('order_state', data.get('state', ''))
-            pincode = str(data.get('order_pincode', data.get('pincode', '')))
-
-            # Format exactly like training data
-            input_text = f"Address_line_1: {addr1} Address_line_2: {addr2} City: {city} State: {state} Pincode: {pincode}"
+            input_text = data.get("inputs", data.get("text", str(data)))
         else:
             input_text = str(data)
 
-        logger.debug(f"Parsed input text: {input_text}")
-
-        # Create tensors directly on target device
+        # Tokenize input
         inputs = tokenizer(
             input_text,
             add_special_tokens=True,
@@ -99,43 +79,23 @@ def predict_fn(data, model_dict):
         )
 
         # Move inputs to device
-        if device.type == 'cuda':
-            inputs = {k: v.cuda() for k, v in inputs.items()}
-
-        logger.debug(f"Inputs moved to device: {device}")
-
-        # Log tensor devices and dtypes
-        for k, v in inputs.items():
-            logger.debug(f"Input '{k}' - Device: {v.device}, Shape: {v.shape}, Dtype: {v.dtype}")
+        inputs = {k: v.to(device) for k, v in inputs.items()}
 
         # Generate prediction
-        logger.info("Generating prediction")
         with torch.no_grad():
             if device.type == 'cuda':
                 torch.cuda.empty_cache()
 
-            try:
-                # Run inference
-                outputs = model(**inputs)
-                # Convert to float32 before softmax to ensure compatibility
-                logits = outputs.logits.to(dtype=torch.float32)
-                predictions = torch.softmax(logits, dim=1)
-
-            except RuntimeError as e:
-                logger.error("Error during inference:")
-                logger.error(f"Model device: {next(model.parameters()).device}")
-                logger.error(f"Input devices: {[f'{k}: {v.device}' for k, v in inputs.items()]}")
-                raise
+            outputs = model(**inputs)
+            predictions = torch.softmax(outputs.logits, dim=1)
 
-        # Move predictions to CPU and ensure float32
-        predictions = predictions.cpu().float().numpy()
+        # Move predictions to CPU and convert to numpy
+        predictions = predictions.cpu().numpy()
 
         return predictions
 
     except Exception as e:
         logger.error(f"Error during prediction: {str(e)}")
-        logger.error(f"Model device: {next(model.parameters()).device}")
-        logger.error(f"Input devices: {[f'{k}: {v.device}' for k, v in inputs.items()]}")
         raise
 
 def input_fn(request_body, request_content_type):
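
For a quick sanity check outside SageMaker, the updated handlers can be exercised directly. The sketch below is a minimal local smoke test under stated assumptions, not part of the commit: it assumes the file is importable as inference, that HF_MODEL_ID points at a valid binary-classification checkpoint (the id "my-org/my-classifier" is a placeholder), and it mirrors the model_fn then predict_fn order in which the SageMaker Hugging Face Inference Toolkit invokes these handlers.

import os
import json

# Hypothetical checkpoint id; model_fn reads HF_MODEL_ID from the environment,
# so the model_dir argument is effectively unused in this script.
os.environ["HF_MODEL_ID"] = "my-org/my-classifier"

from inference import model_fn, predict_fn

model_dict = model_fn("/opt/ml/model")

# Matches the new dict parsing in predict_fn: data.get("inputs", data.get("text", ...))
payload = {"inputs": "example text to classify"}

probs = predict_fn(payload, model_dict)  # numpy array of shape (1, 2) with num_labels=2
print(json.dumps({"probabilities": probs.tolist()}))

Against a deployed endpoint, the same JSON body, {"inputs": "..."}, is what a client would post; the toolkit routes it through input_fn before it reaches predict_fn.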