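"""Custom inference endpoint handler that serves a quantized Phi-3 GGUF model
through llama-cpp-python.

``EndpointHandler`` is a singleton: the expensive model load happens once per
process, and every later instantiation reuses the same ``Llama`` instance.
"""
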
from llama_cpp import Llama
from typing import Dict, List, Any, Union
import os


class EndpointHandler:
    _instance = None
    _model_loaded = False

    def __new__(cls, *args, **kwargs):
        # Singleton: every instantiation returns the same object, so the model
        # loaded in __init__ is created at most once per process.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self, model_path=""):
        if not self._model_loaded:
            # Use an explicitly supplied path if given; otherwise default to a
            # GGUF file sitting next to this script.
            if model_path:
                self.model_path = model_path
            else:
                script_dir = os.path.dirname(os.path.abspath(__file__))
                model_filename = "Phi-3-medium-128k-instruct-IQ2_XS.gguf"
                self.model_path = os.path.join(script_dir, model_filename)

            if not os.path.exists(self.model_path):
                raise FileNotFoundError(f"Model path does not exist: {self.model_path}")

            self.llm = Llama(
                model_path=self.model_path,
                n_ctx=5000,      # context window size in tokens
                n_gpu_layers=-1  # offload all layers to the GPU when one is available
            )

            self.generation_kwargs = {
                "max_tokens": 400,
                # Phi-3 chat-template delimiters, so generation stops cleanly
                # at the end of the assistant turn.
                "stop": ["<|end|>", "<|user|>", "<|assistant|>"],
                "top_k": 1  # greedy decoding
            }

            self._model_loaded = True

    def __call__(self, data: Union[Dict[str, Any], str]) -> List[Dict[str, Any]]:
        """
        Data args:
            inputs (:obj:`dict`): The input prompts for the LLM, with optional
                system instructions under "system" and the user message under
                "message".
            str: A plain string used directly as the user message of a chat
                completion.

        Return:
            A :obj:`list` of :obj:`dict` in the form ``[{"generated_text": ...}]``,
            which will be serialized and returned.
        """
        if isinstance(data, dict):
            inputs = data.get("inputs", {})
            system_instructions = inputs.get("system", "")
            user_message = inputs.get("message", "")

            if not user_message:
                raise ValueError("No user message provided for the model.")

            response = self.llm.create_chat_completion(
                messages=[
                    {"role": "system", "content": system_instructions},
                    {"role": "user", "content": user_message}
                ],
                **self.generation_kwargs
            )

        elif isinstance(data, str):
            response = self.llm.create_chat_completion(
                messages=[
                    {"role": "user", "content": data}
                ],
                **self.generation_kwargs
            )

        else:
            raise ValueError(f"Invalid input type. Expected dict or str, got {type(data)}")

        # llama-cpp-python returns an OpenAI-style completion dict:
        # {"choices": [{"message": {"content": ...}}], ...}
        try:
            generated_text = response["choices"][0]["message"].get("content", "")
        except (KeyError, IndexError) as exc:
            raise ValueError("Unexpected response structure: missing 'content' in choices[0]['message']") from exc

        return [{"generated_text": generated_text}]


def main():
    handler = EndpointHandler()

    # Dictionary input with separate system instructions and user message.
    data_dict = {"inputs": {"system": "System instructions", "message": "Hello, how are you?"}}
    result_dict = handler(data_dict)
    print("Dictionary input result:", result_dict)

    # Plain string input.
    data_str = "Hello, how are you?"
    result_str = handler(data_str)
    print("String input result:", result_str)


if __name__ == "__main__":
    main()
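

# A minimal sketch of how this handler might be invoked once deployed as a
# custom handler (handler.py) on a service such as Hugging Face Inference
# Endpoints. The URL, token, and exact JSON interface are assumptions for
# illustration, not part of this file:
#
#   curl https://<your-endpoint-url> \
#       -H "Authorization: Bearer <your-token>" \
#       -H "Content-Type: application/json" \
#       -d '{"inputs": {"system": "You are a helpful assistant.", "message": "Hello!"}}'
#
# The response body would be the serialized return value of __call__, e.g.
# [{"generated_text": "..."}].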