import json
import logging
import os
from types import SimpleNamespace

import torch
import torch.nn as nn
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

logger = logging.getLogger(__name__)


class PhiForSequenceClassification(nn.Module):
    """Wraps a causal Phi model with a linear classification head."""

    def __init__(self, base_model, num_labels=2):
        super().__init__()
        self.phi = base_model

        # Match the classifier dtype to the base model (e.g. bfloat16 on GPU).
        dtype = next(base_model.parameters()).dtype
        self.classifier = nn.Linear(self.phi.config.hidden_size, num_labels, dtype=dtype)

    def forward(self, **inputs):
        outputs = self.phi(**inputs, output_hidden_states=True)

        # Classify from the hidden state of the final token, so inputs should
        # not be right-padded.
        last_hidden_state = outputs.hidden_states[-1][:, -1, :]
        logits = self.classifier(last_hidden_state)
        return SimpleNamespace(logits=logits)


def model_fn(model_dir, context=None):
    """Load the model for inference."""
    try:
        # Load from the Hub when HF_MODEL_ID is set, otherwise fall back to
        # the model artifacts packaged in model_dir.
        model_id = os.getenv("HF_MODEL_ID", model_dir)

        device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        if device.type == 'cuda':
            torch.cuda.empty_cache()
        logger.info(f"Using device: {device}")

        tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
        # Phi tokenizers typically ship without a padding token; reuse EOS so
        # padded batches can still be tokenized.
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        config = AutoConfig.from_pretrained(model_id, trust_remote_code=True)

        base_model = AutoModelForCausalLM.from_pretrained(
            model_id,
            config=config,
            torch_dtype=torch.bfloat16 if device.type == 'cuda' else torch.float32,
            trust_remote_code=True
        )

        # Note: the classification head is freshly initialized here; load
        # fine-tuned head weights separately if they are part of the artifact.
        model = PhiForSequenceClassification(base_model, num_labels=2)
        model = model.to(device)

        if device.type == 'cuda':
            torch.backends.cudnn.benchmark = True

        model.eval()
        logger.info(f"Model loaded successfully on {device}")

        return {
            "model": model,
            "tokenizer": tokenizer,
            "device": device
        }
    except Exception as e:
        logger.error(f"Error loading model: {str(e)}")
        raise


def predict_fn(data, model_dict):
    """Make a prediction."""
    try:
        logger.info("Starting prediction")
        model = model_dict["model"]
        tokenizer = model_dict["tokenizer"]
        device = model_dict["device"]

        # Accept a raw string, a JSON payload with "inputs"/"text", or
        # anything else coerced to a string.
        if isinstance(data, str):
            input_text = data
        elif isinstance(data, dict):
            input_text = data.get("inputs", data.get("text", str(data)))
        else:
            input_text = str(data)

        # No padding: the classifier reads the hidden state of the final
        # token, so right-padding would shift it onto pad tokens.
        inputs = tokenizer(
            input_text,
            add_special_tokens=True,
            max_length=128,
            truncation=True,
            return_tensors='pt'
        )
        inputs = {k: v.to(device) for k, v in inputs.items()}

        with torch.no_grad():
            if device.type == 'cuda':
                torch.cuda.empty_cache()

            outputs = model(**inputs)
            predictions = torch.softmax(outputs.logits, dim=1)

        # Upcast to float32 before .numpy(): NumPy has no bfloat16 dtype.
        predictions = predictions.float().cpu().numpy()

        return predictions

    except Exception as e:
        logger.error(f"Error during prediction: {str(e)}")
        raise


def input_fn(request_body, request_content_type):
    """Parse the input request."""
    if request_content_type == "application/json":
        try:
            data = json.loads(request_body)
        except (json.JSONDecodeError, TypeError):
            # Fall back to the raw body if it is not valid JSON.
            data = request_body
        return data
    else:
        return request_body


def output_fn(prediction, response_content_type):
    """Format the output."""
    if response_content_type == "application/json":
        return json.dumps(prediction.tolist())
    else:
        raise ValueError(f"Unsupported content type: {response_content_type}")
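

# Local smoke test: a minimal sketch of the order in which the SageMaker
# inference toolkit invokes these handlers (model_fn -> input_fn ->
# predict_fn -> output_fn). It assumes HF_MODEL_ID points at a Phi checkpoint
# that can be loaded on this machine; "./model" is a placeholder for the
# artifact directory SageMaker would normally supply.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    model_dict = model_fn("./model")
    payload = json.dumps({"inputs": "This product exceeded my expectations."})

    data = input_fn(payload, "application/json")
    prediction = predict_fn(data, model_dict)
    print(output_fn(prediction, "application/json"))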