from typing import Any, Dict, List
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch

MAX_INPUT_LENGTH = 256
MAX_OUTPUT_LENGTH = 128

class EndpointHandler:
    def __init__(self, model_dir: str = "", **kwargs: Any) -> None:
        """
        Initializes the model and tokenizer when the endpoint starts.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
        # Assuming you fine-tuned CodeT5+ for a sequence-to-sequence task
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_dir)
        self.model.eval()  # disable dropout for deterministic inference
        # Run on GPU when one is available; otherwise fall back to CPU.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model.to(self.device)

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Handles incoming inference requests.
        """
        inputs = data.get("inputs")
        if not inputs:
            raise ValueError("No 'inputs' found in the request data.")

        # Wrap a single string in a list so the tokenizer always receives a batch.
        if isinstance(inputs, str):
            inputs = [inputs]

        # Pre-processing
        # Adjust max_length and padding based on your model's training and task
        tokenized_inputs = self.tokenizer(
            inputs,
            max_length=MAX_INPUT_LENGTH,
            padding=True,
            truncation=True,
            return_tensors="pt"
        ).to(self.device)

        # Inference
        with torch.no_grad():
            outputs = self.model.generate(
                input_ids=tokenized_inputs["input_ids"],
                attention_mask=tokenized_inputs["attention_mask"],
                # Tune these generation arguments to your task.
                max_length=MAX_OUTPUT_LENGTH,
                num_beams=8,
                no_repeat_ngram_size=3,  # block repeated 3-grams in the output
                pad_token_id=self.tokenizer.pad_token_id,
            )

        # Post-processing
        decoded_outputs = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)

        # Format the output as a list of dictionaries
        results = [{"generated_text": text} for text in decoded_outputs]
        return results
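
# --- Local smoke test (illustrative sketch) ---
# Hugging Face Inference Endpoints construct EndpointHandler(model_dir) once at
# startup and then call the instance for each request. The model directory used
# below is a placeholder assumption; point it at your exported model files.
if __name__ == "__main__":
    handler = EndpointHandler(model_dir=".")  # hypothetical path, adjust to your layout
    # A single string is accepted...
    print(handler({"inputs": "def add(a, b): return a + b"}))
    # ...as is a batch of strings.
    print(handler({"inputs": ["print('hello')", "x = [i * i for i in range(10)]"]}))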