from typing import Any, Dict

import torch
from huggingface_inference_toolkit.logging import logger
from transformers import AutoModelForCausalLM, AutoProcessor
from transformers.image_utils import load_image

# Placeholder that must appear in the prompt once per attached image, each
# occurrence followed by SEPARATOR. An empty placeholder would make the marker
# below degenerate to a bare newline, so no image token would reach the prompt.
IMAGE_TOKENS = "<image_start><image><image_end>"
SEPARATOR = "\n"


class EndpointHandler:
    """Inference Endpoints handler that runs chat-style generation with Magma.

    The handler accepts OpenAI-like chat messages whose ``content`` is either a
    plain string or a list of ``{"type": "text"}`` / ``{"type": "image_url"}``
    parts, flattens them into the text prompt, and returns the generated text.
    """

    def __init__(
        self,
        model_dir: str = "alvarobartt/Magma-8B",
        **kwargs: Any,  # type: ignore
    ) -> None:
        """Load the model and its processor from ``model_dir``.

        Args:
            model_dir: Model repository id or local path passed to
                ``from_pretrained``.
            **kwargs: Accepted for interface compatibility; not used here.
        """
        self.model = AutoModelForCausalLM.from_pretrained(
            model_dir, trust_remote_code=True, torch_dtype=torch.bfloat16
        ).eval()
        self.model.to("cuda")

        self.processor = AutoProcessor.from_pretrained(
            model_dir, trust_remote_code=True
        )

    @staticmethod
    def _prepare_messages(raw_messages: Any) -> Any:
        """Flatten chat messages to plain-text content and collect their images.

        Returns a ``(messages, images)`` tuple: ``messages`` is a list of
        ``{"role", "content"}`` dicts with string content for list-style
        messages (other content is passed through unchanged), and ``images``
        holds the loaded images in the order they were encountered.
        """
        image_marker = f"{IMAGE_TOKENS}{SEPARATOR}"
        messages, images = [], []

        for message in raw_messages:
            logger.info(f"Processing {message=}...")
            message_content = message["content"]

            if not isinstance(message_content, list):
                messages.append(
                    {"role": message["role"], "content": message_content}
                )
                continue

            text_parts = []
            image_count = 0
            for content in message_content:
                logger.info(f"{content=} is of type {content['type']}")
                if content["type"] == "text":
                    text_parts.append(content["text"])
                elif content["type"] == "image_url":
                    images.append(load_image(content["image_url"]["url"]))
                    image_count += 1

            text = "".join(text_parts)
            # Count images per message (not across the whole conversation) and
            # only after all text parts are joined, so placeholders the client
            # already wrote are honoured regardless of part order and markers
            # are never duplicated.
            missing_markers = image_count - text.count(image_marker)
            if missing_markers > 0:
                text = image_marker * missing_markers + text

            messages.append({"role": message["role"], "content": text})

        return messages, images

    def __call__(self, data: Dict[str, Any]) -> Any:
        """Generate a reply for the chat messages in ``data["inputs"]``.

        Args:
            data: Request body. ``inputs`` is the list of messages; the
                optional keys ``max_new_tokens`` (alias ``max_tokens``,
                default 128) and ``temperature`` (default 0.0) are forwarded
                to generation. Any other key is ignored and reported in the
                log. The dict is not modified.

        Returns:
            ``{"generated_text": <decoded, stripped completion>}``.

        Raises:
            ValueError: If ``data`` has no ``inputs`` key.
        """
        logger.info(f"Received payload with {data}")
        if "inputs" not in data:
            raise ValueError(
                "The request body must contain a key 'inputs' with a list of messages."
            )

        logger.info("Processing the messages...")
        messages, images = self._prepare_messages(data["inputs"])

        # Work on a filtered copy so the caller's payload is left untouched.
        params = {key: value for key, value in data.items() if key != "inputs"}

        logger.info(f"Applying chat template to {messages=}")
        prompt = self.processor.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )

        logger.info(f"Processing {len(images)} images...")
        # NOTE(review): a text-only request reaches this point with an empty
        # image list; whether the remote-code processor accepts that and still
        # returns "pixel_values" / "image_sizes" is not visible here - confirm.
        inputs = self.processor(images=images, texts=prompt, return_tensors="pt")
        # Add a leading dimension to the image tensors before generation.
        inputs["pixel_values"] = inputs["pixel_values"].unsqueeze(0)
        inputs["image_sizes"] = inputs["image_sizes"].unsqueeze(0)
        inputs = inputs.to("cuda").to(torch.bfloat16)

        generation_args = {
            "max_new_tokens": params.get(
                "max_new_tokens", params.get("max_tokens", 128)
            ),
            "temperature": params.get("temperature", 0.0),
            "do_sample": False,  # temperature won't really work unless this is set to True
            "use_cache": True,
            "num_beams": 1,
        }
        # "max_tokens" is consumed above as an alias, so it is not "skipped".
        skipped = set(params) - set(generation_args) - {"max_tokens"}
        logger.info(
            f"Running text generation with the following {generation_args=} (skipped {skipped})"
        )

        with torch.inference_mode():
            generate_ids = self.model.generate(**inputs, **generation_args)

        # Drop the prompt tokens so only the newly generated ones are decoded.
        generate_ids = generate_ids[:, inputs["input_ids"].shape[-1] :]
        response = self.processor.decode(
            generate_ids[0], skip_special_tokens=True
        ).strip()

        return {"generated_text": response}