from typing import Dict, Any, List
import logging

import torch
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftConfig, PeftModel

# Use bfloat16 on Ampere-or-newer GPUs (compute capability >= 8), float16 otherwise.
dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16
# LOGGER = logging.getLogger(__name__)
# logging.basicConfig(level=logging.INFO)
# device = "cuda" if torch.cuda.is_available() else "cpu"
class EndpointHandler:
    def __init__(self, path=""):
        # Load the tokenizer and the 8-bit quantized causal LM from the repository path.
        tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
        model = AutoModelForCausalLM.from_pretrained(
            path,
            return_dict=True,
            device_map="auto",
            load_in_8bit=True,
            torch_dtype=dtype,
            trust_remote_code=True,
        )

        # Greedy, single-sequence generation capped at 512 new tokens.
        generation_config = model.generation_config
        generation_config.max_new_tokens = 512
        generation_config.temperature = 0
        generation_config.num_return_sequences = 1
        generation_config.pad_token_id = tokenizer.eos_token_id
        generation_config.eos_token_id = tokenizer.eos_token_id
        self.generation_config = generation_config

        self.pipeline = transformers.pipeline(
            "text-generation", model=model, tokenizer=tokenizer
        )
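
        # Alternative loading path (kept below for reference): resolve the base model
        # from the PEFT config, then attach the LoRA adapter stored at `path`.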
        # config = PeftConfig.from_pretrained(path)
        # model = AutoModelForCausalLM.from_pretrained(config.base_model_name_or_path, load_in_8bit=True, device_map='auto')
        # self.tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)
        # # Load the Lora model
        # self.model = PeftModel.from_pretrained(model, path)
    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
# """
# Args:
# data (Dict): The payload with the text prompt and generation parameters.
# """
# LOGGER.info(f"Received data: {data}")
        # Get inputs and fail fast if the payload has no prompt.
        prompt = data.pop("inputs", None)
        # parameters = data.pop("parameters", None)
        if prompt is None:
            raise ValueError("Missing prompt.")
        # # Preprocess
        # input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(device)
        # # Forward
        # LOGGER.info(f"Start generation.")
        # if parameters is not None:
        #     output = self.model.generate(input_ids=input_ids, **parameters)
        # else:
        #     output = self.model.generate(input_ids=input_ids)
        # # Postprocess
        # prediction = self.tokenizer.decode(output[0])
        # LOGGER.info(f"Generated text: {prediction}")
        # return {"generated_text": prediction}
        # Generate with the configured settings; the pipeline returns a list of
        # {"generated_text": ...} dicts.
        result = self.pipeline(prompt, generation_config=self.generation_config)
        return result
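

# Minimal local smoke test (a sketch, not part of the deployed handler): assumes a
# CUDA-capable machine with bitsandbytes installed and model weights available at the
# placeholder path "MODEL_PATH".
if __name__ == "__main__":
    handler = EndpointHandler(path="MODEL_PATH")
    payload = {"inputs": "Write a short greeting."}
    print(handler(payload))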