from typing import Dict, Any, List
import logging

import torch
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftConfig, PeftModel

# Prefer bfloat16 on GPUs with compute capability 8.0 or newer, otherwise fall
# back to float16; the availability check keeps the import from failing on CPU-only hosts.
dtype = (
    torch.bfloat16
    if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8
    else torch.float16
)
# LOGGER = logging.getLogger(__name__)
# logging.basicConfig(level=logging.INFO)
# device = "cuda" if torch.cuda.is_available() else "cpu"
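

# Custom handler for Hugging Face Inference Endpoints: the service constructs
# EndpointHandler(path) once with the model repository directory, then calls
# the instance with each request payload ({"inputs": ..., "parameters": ...}).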
class EndpointHandler:
    def __init__(self, path: str = ""):
        tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
        # Load the model in 8-bit via bitsandbytes; non-quantized modules use `dtype`.
        model = AutoModelForCausalLM.from_pretrained(
            path,
            return_dict=True,
            device_map="auto",
            load_in_8bit=True,
            torch_dtype=dtype,
            trust_remote_code=True,
        )

        # Generation defaults: up to 512 new tokens, a single sequence,
        # padding and stopping on the tokenizer's EOS token.
        generation_config = model.generation_config
        generation_config.max_new_tokens = 512
        generation_config.temperature = 0
        generation_config.num_return_sequences = 1
        generation_config.pad_token_id = tokenizer.eos_token_id
        generation_config.eos_token_id = tokenizer.eos_token_id
        self.generation_config = generation_config

        self.pipeline = transformers.pipeline(
            "text-generation", model=model, tokenizer=tokenizer
        )
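
        # Alternative kept for reference: load the base model and attach the
        # LoRA adapter with PEFT instead of using a single merged checkpoint.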
        # config = PeftConfig.from_pretrained(path)
        # model = AutoModelForCausalLM.from_pretrained(config.base_model_name_or_path, load_in_8bit=True, device_map='auto')
        # self.tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)
        # # Load the Lora model
        # self.model = PeftModel.from_pretrained(model, path)

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        # """
        # Args:
        #     data (Dict): The payload with the text prompt and generation parameters.
        # """
        # LOGGER.info(f"Received data: {data}")

        # Get inputs
        prompt = data.pop("inputs", None)
        # parameters = data.pop("parameters", None)
        # if prompt is None:
        #     raise ValueError("Missing prompt.")

        # # Preprocess
        # input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(device)

        # # Forward
        # LOGGER.info("Start generation.")
        # if parameters is not None:
        #     output = self.model.generate(input_ids=input_ids, **parameters)
        # else:
        #     output = self.model.generate(input_ids=input_ids)

        # # Postprocess
        # prediction = self.tokenizer.decode(output[0])
        # LOGGER.info(f"Generated text: {prediction}")
        # return {"generated_text": prediction}

        # Run the text-generation pipeline and return its list of generations.
        result = self.pipeline(prompt, generation_config=self.generation_config)
        return result
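

if __name__ == "__main__":
    # Rough local smoke test, not part of the endpoint contract. Assumes model
    # weights in "./model" (placeholder path) and a CUDA GPU with bitsandbytes
    # installed, since the constructor loads the model in 8-bit.
    handler = EndpointHandler(path="./model")
    print(handler({"inputs": "Write a one-sentence summary of what this endpoint does."}))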