# gguf-inference/handler.py
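"""Custom inference handler for a GGUF (llama.cpp) model.

The handler loads the model once per process and exposes a callable that accepts
either a structured dict payload or a plain string and returns generated text.
"""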
from llama_cpp import Llama
from typing import Dict, List, Any, Union
import os
class EndpointHandler:
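    """Singleton endpoint handler that loads a GGUF model once per process with
    llama-cpp-python and serves chat-completion requests."""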
_instance = None # Singleton instance
_model_loaded = False # Flag to check if the model is loaded
    def __new__(cls, *args, **kwargs):  # accept __init__'s arguments so EndpointHandler(model_path=...) works
if not cls._instance:
cls._instance = super(EndpointHandler, cls).__new__(cls)
return cls._instance
    def __init__(self, model_path=""):
        # NOTE: the model_path argument is currently ignored; the path is
        # derived from this file's directory below.
        if not self._model_loaded:
# Construct the model path assuming the model is in the same directory as the handler file
script_dir = os.path.dirname(os.path.abspath(__file__))
model_filename = "Phi-3-medium-128k-instruct-IQ2_XS.gguf"
self.model_path = os.path.join(script_dir, model_filename)
# Check if the model file exists
if not os.path.exists(self.model_path):
raise ValueError(f"Model path does not exist: {self.model_path}")
# Load the GGUF model using llama_cpp
self.llm = Llama(
model_path=self.model_path,
n_ctx=5000, # Set context length to 5000 tokens
# n_threads=12, # Adjust the number of CPU threads as per your machine
n_gpu_layers=-1 # Adjust based on GPU availability
)
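            # Note: n_gpu_layers=-1 offloads all model layers to the GPU when
            # llama-cpp-python is built with GPU support; otherwise inference runs on the CPU.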
# Define generation kwargs for the model
self.generation_kwargs = {
"max_tokens": 400, # Respond with up to 400 tokens
"stop": ["<|end|>", "<|user|>", "<|assistant|>"],
"top_k": 1 # Greedy decoding
}
self._model_loaded = True
def __call__(self, data: Union[Dict[str, Any], str]) -> List[Dict[str, Any]]:
"""
Data args:
inputs (:obj:`dict`): The input prompts for the LLM including system instructions and user messages.
str: A string input which will create a chat completion.
Return:
A :obj:`list` | `dict`: will be serialized and returned.
"""
if isinstance(data, dict):
# Extract inputs
inputs = data.get("inputs", {})
system_instructions = inputs.get("system", "")
user_message = inputs.get("message", "")
if not user_message:
raise ValueError("No user message provided for the model.")
# Run inference with llama_cpp
response = self.llm.create_chat_completion(
messages=[
{"role": "system", "content": system_instructions},
{"role": "user", "content": user_message}
],
**self.generation_kwargs
)
elif isinstance(data, str):
# Create a chat completion from the input string
response = self.llm.create_chat_completion(
messages=[
{"role": "user", "content": data}
],
**self.generation_kwargs
)
else:
            raise ValueError(f"Invalid input type. Expected dict or str, got {type(data)}")
# Access generated text based on the response structure
        try:
            generated_text = response["choices"][0]["message"].get("content", "")
        except (KeyError, IndexError):
            raise ValueError("Unexpected response structure: could not read choices[0]['message'] from the completion response")
# Return the generated text
return [{"generated_text": generated_text}]
def main():
    handler = EndpointHandler()  # EndpointHandler implements __call__, so the instance is callable
# Test 1: Dictionary input
data_dict = {"inputs": {"system": "System instructions", "message": "Hello, how are you?"}}
result_dict = handler(data_dict)
print("Dictionary input result:", result_dict)
# Test 2: String input
data_str = "Hello, how are you?"
result_str = handler(data_str)
print("String input result:", result_str)
if __name__ == "__main__":
main()
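# Example request body this handler expects when deployed behind an endpoint
# (shape inferred from __call__ above; no other fields are assumed):
#   {"inputs": {"system": "You are a helpful assistant.", "message": "Hello, how are you?"}}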