|
from typing import Any, Dict, List |
|
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM |
|
import torch |
|
|
|
MAX_INPUT_LENGTH = 256 |
|
MAX_OUTPUT_LENGTH = 128 |
|
|
|
class EndpointHandler:
    """Custom inference handler for a seq2seq (encoder-decoder) model.

    Loads the tokenizer and model once at endpoint startup and serves
    batched text-to-text generation requests. Expected request payload::

        {"inputs": "some text" | ["text a", "text b"],
         "parameters": {...optional generate() overrides...}}
    """

    def __init__(self, model_dir: str = "", **kwargs: Any) -> None:
        """Initialize the model and tokenizer when the endpoint starts.

        Args:
            model_dir: Directory containing the saved model/tokenizer files.
            **kwargs: Ignored; accepted for handler-interface compatibility.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_dir)

        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_dir)
        # Disable dropout etc. — this endpoint only runs inference.
        self.model.eval()

        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model.to(self.device)

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Handle an incoming inference request.

        Args:
            data: Request payload. Must contain "inputs" (a string or a list
                of strings). May contain "parameters", a dict of keyword
                arguments forwarded to ``model.generate`` that override the
                built-in defaults (e.g. ``num_beams``, ``max_length``).

        Returns:
            One ``{"generated_text": ...}`` dict per input, in input order.

        Raises:
            ValueError: If "inputs" is missing, empty, or not a str/list.
        """
        inputs = data.get("inputs")
        if not inputs:
            raise ValueError("No 'inputs' found in the request data.")

        # Normalize a single string to a one-element batch.
        if isinstance(inputs, str):
            inputs = [inputs]
        elif not isinstance(inputs, list):
            # Fail fast with a clear message instead of an opaque
            # tokenizer error deep in the call stack.
            raise ValueError(
                "'inputs' must be a string or a list of strings, "
                f"got {type(inputs).__name__}."
            )

        tokenized_inputs = self.tokenizer(
            inputs,
            max_length=MAX_INPUT_LENGTH,
            padding=True,
            truncation=True,
            return_tensors="pt",
        ).to(self.device)

        # Default decoding settings; per-request "parameters" override them.
        gen_kwargs: Dict[str, Any] = {
            "max_length": MAX_OUTPUT_LENGTH,
            "num_beams": 8,
            "no_repeat_ngram_size": 3,
            "pad_token_id": self.tokenizer.pad_token_id,
        }
        gen_kwargs.update(data.get("parameters") or {})

        # inference_mode() disables autograd tracking (and view/version
        # bookkeeping) — strictly faster than no_grad() for pure inference.
        with torch.inference_mode():
            outputs = self.model.generate(**tokenized_inputs, **gen_kwargs)

        decoded_outputs = self.tokenizer.batch_decode(
            outputs, skip_special_tokens=True
        )

        return [{"generated_text": text} for text in decoded_outputs]
|
|