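"""Custom inference handler for a GGUF build of Phi-3-medium-128k-instruct,
served with llama-cpp-python.

The ``EndpointHandler`` class below follows the ``handler.py`` convention used by
Hugging Face Inference Endpoints: ``__call__(data)`` accepts either a dict payload
or a plain string and returns ``[{"generated_text": ...}]``.
"""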
from llama_cpp import Llama
from typing import Dict, List, Any, Union
import os


class EndpointHandler:
    _instance = None  # Singleton instance
    _model_loaded = False  # Flag to check if the model is loaded

    def __new__(cls, *args, **kwargs):
        # Create the singleton on first use so the multi-gigabyte GGUF model is
        # loaded only once per process. object.__new__ accepts no extra arguments,
        # so *args/**kwargs are swallowed here rather than forwarded.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance._model_loaded = False
        return cls._instance

    def __init__(self, model_path=""):
        # The model_path argument is not used; the GGUF file is expected to sit
        # next to this handler file. Only load the model once per process.
        if not self._model_loaded:
            # Construct the model path relative to the directory containing this file
            script_dir = os.path.dirname(os.path.abspath(__file__))
            model_filename = "Phi-3-medium-128k-instruct-IQ2_XS.gguf"
            self.model_path = os.path.join(script_dir, model_filename)

            # Check if the model file exists
            if not os.path.exists(self.model_path):
                raise ValueError(f"Model path does not exist: {self.model_path}")

            # Load the GGUF model with llama-cpp-python
            self.llm = Llama(
                model_path=self.model_path,
                n_ctx=5000,  # Context window of 5,000 tokens
                # n_threads=12,  # Optionally pin the number of CPU threads for your machine
                n_gpu_layers=-1  # -1 offloads all layers to the GPU when one is available
            )

            # Generation settings shared by every request
            self.generation_kwargs = {
                "max_tokens": 400,  # Respond with up to 400 tokens
                "stop": ["<|end|>", "<|user|>", "<|assistant|>"],  # Phi-3 chat-template special tokens
                "top_k": 1  # Greedy decoding
            }

            self._model_loaded = True

    def __call__(self, data: Union[Dict[str, Any], str]) -> List[Dict[str, Any]]:
        """
        Data args:
            inputs (:obj:`dict`): The input prompts for the LLM including system instructions and user messages.
            str: A string input which will create a chat completion.

        Return:
            A :obj:`list` | `dict`: will be serialized and returned.
        """
        if isinstance(data, dict):
            # Extract inputs
            inputs = data.get("inputs", {})
            system_instructions = inputs.get("system", "")
            user_message = inputs.get("message", "")

            if not user_message:
                raise ValueError("No user message provided for the model.")

            # Run inference with llama_cpp
            response = self.llm.create_chat_completion(
                messages=[
                    {"role": "system", "content": system_instructions},
                    {"role": "user", "content": user_message}
                ],
                **self.generation_kwargs
            )

        elif isinstance(data, str):
            # Create a chat completion from the input string
            response = self.llm.create_chat_completion(
                messages=[
                    {"role": "user", "content": data}
                ],
                **self.generation_kwargs
            )

        else:
            raise ValueError("Invalid input type. Expected dict or str, got {}".format(type(data)))

        # Extract the generated text from the OpenAI-style response structure
        try:
            generated_text = response["choices"][0]["message"]["content"]
        except (KeyError, IndexError):
            raise ValueError("Unexpected response structure: missing choices[0]['message']['content']")

        # Return the generated text
        return [{"generated_text": generated_text}]


def main():
    # Simple local smoke test for EndpointHandler
    handler = EndpointHandler()

    # Test 1: Dictionary input
    data_dict = {"inputs": {"system": "System instructions", "message": "Hello, how are you?"}}
    result_dict = handler(data_dict)
    print("Dictionary input result:", result_dict)

    # Test 2: String input
    data_str = "Hello, how are you?"
    result_str = handler(data_str)
    print("String input result:", result_str)

    # Test 3: Invalid input type
    data_invalid = 123
    try:
        handler(data_invalid)
    except ValueError as e:
        print("Invalid input type error:", e)

if __name__ == "__main__":
    main()
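
# Once this handler is deployed behind an HTTP endpoint, a request can be issued with
# the same payload shape that __call__ expects. This is a sketch: the URL and token
# below are placeholders, not values from this repository.
#
#   import requests
#
#   response = requests.post(
#       "https://<your-endpoint-url>",
#       headers={
#           "Authorization": "Bearer <your-token>",
#           "Content-Type": "application/json",
#       },
#       json={"inputs": {"system": "System instructions", "message": "Hello, how are you?"}},
#   )
#   print(response.json())  # -> [{"generated_text": "..."}]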