import warnings
from typing import Any, Dict, List

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


class EndpointHandler:
    INSTRUCTION_KEY = "### Instruction:"
    RESPONSE_KEY = "### Response:"
    END_KEY = "### End"
    INTRO_BLURB = (
        "Below is an instruction that describes a task. "
        "Write a response that appropriately completes the request."
    )
    PROMPT_FOR_GENERATION_FORMAT = """{intro}
{instruction_key}
{instruction}
{response_key}
""".format(
        intro=INTRO_BLURB,
        instruction_key=INSTRUCTION_KEY,
        instruction="{instruction}",
        response_key=RESPONSE_KEY,
    )

    def __init__(
        self,
        path: str,
        torch_dtype: torch.dtype = torch.bfloat16,
        trust_remote_code: bool = True,
    ) -> None:
        # Load the model weights from the endpoint's repository path.
        self.model = AutoModelForCausalLM.from_pretrained(
            path, torch_dtype=torch_dtype, trust_remote_code=trust_remote_code
        )
        tokenizer = AutoTokenizer.from_pretrained(
            "mosaicml/mpt-7b-instruct", trust_remote_code=trust_remote_code
        )
        if tokenizer.pad_token_id is None:
            warnings.warn(
                "pad_token_id is not set for the tokenizer. Using eos_token_id as pad_token_id."
            )
            tokenizer.pad_token = tokenizer.eos_token
        tokenizer.padding_side = "right"
        self.tokenizer = tokenizer

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.eval()
        self.model.to(device=self.device, dtype=torch_dtype)

        # Default generation settings; callers can override any of these
        # via the "parameters" field of the request payload.
        self.generate_kwargs = {
            "temperature": 0.01,
            "top_p": 0.92,
            "top_k": 0,
            "max_new_tokens": 512,
            "use_cache": True,
            "do_sample": True,
            "eos_token_id": self.tokenizer.eos_token_id,
            "pad_token_id": self.tokenizer.pad_token_id,
            "repetition_penalty": 1.0,
        }

    def format_instruction(self, instruction: str) -> str:
        # Wrap the raw instruction in the Alpaca-style prompt template.
        return self.PROMPT_FOR_GENERATION_FORMAT.format(instruction=instruction)

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, str]]:
        # Process input: accept {"inputs": ..., "parameters": {...}} or a bare payload.
        inputs = data.pop("inputs", data)
        # Fall back to an empty dict so the merge below does not fail when
        # the request carries no "parameters" field.
        parameters = data.pop("parameters", None) or {}

        # Preprocess: build the prompt and move the token ids to the model's device.
        s = self.format_instruction(instruction=inputs)
        input_ids = self.tokenizer(s, return_tensors="pt").input_ids.to(self.device)

        # Merge request parameters over the defaults (request values win).
        gkw = {**self.generate_kwargs, **parameters}

        # Generate without tracking gradients.
        with torch.no_grad():
            output_ids = self.model.generate(input_ids, **gkw)

        # Slice the output_ids tensor to keep only the newly generated tokens.
        new_tokens = output_ids[0, input_ids.shape[1]:]
        output_text = self.tokenizer.decode(new_tokens, skip_special_tokens=True)

        return [{"generated_text": output_text}]
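

# Minimal usage sketch, not part of the handler contract: the Hugging Face
# Inference Toolkit normally instantiates EndpointHandler with the repository
# path and calls it with the parsed JSON request body. The model id and the
# example payload below are assumptions for a local smoke test only.
if __name__ == "__main__":
    handler = EndpointHandler(path="mosaicml/mpt-7b-instruct")
    payload = {
        "inputs": "Explain what a tokenizer does in two sentences.",
        "parameters": {"max_new_tokens": 64},
    }
    print(handler(payload))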