from typing import Any, Dict

import torch
from huggingface_inference_toolkit.logging import logger
from transformers import AutoModelForCausalLM, AutoProcessor
from transformers.image_utils import load_image
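
# Tokens used to mark an image inside the prompt: an image placeholder block is
# prepended to the message text and separated from it by a newline.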
IMAGE_TOKENS = "<image_start><image><image_end>"
SEPARATOR = "\n"


class EndpointHandler:
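    """Custom Inference Endpoints handler that serves Magma-8B for multimodal chat."""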

    def __init__(
        self,
        model_dir: str = "alvarobartt/Magma-8B",
        **kwargs: Any,
    ) -> None:
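        # Load the model in bfloat16 and run it on the GPU; `trust_remote_code=True`
        # is required because Magma ships custom modeling and processor code.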
        self.model = AutoModelForCausalLM.from_pretrained(
            model_dir, trust_remote_code=True, torch_dtype=torch.bfloat16
        ).eval()
        self.model.to("cuda")
        self.processor = AutoProcessor.from_pretrained(
            model_dir, trust_remote_code=True
        )

    def __call__(self, data: Dict[str, Any]) -> Any:
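        """Generate a chat completion for the incoming request payload.

        Expects `data["inputs"]` to be a list of chat messages; each message's
        `content` is either a plain string or a list of `{"type": "text", ...}` and
        `{"type": "image_url", ...}` parts. Optional top-level keys such as
        `max_new_tokens` (or `max_tokens`) and `temperature` are forwarded to
        generation; any other keys are ignored.
        """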
logger.info(f"Received payload with {data}") |
|
if "inputs" not in data: |
|
raise ValueError( |
|
"The request body must contain a key 'inputs' with a list of messages." |
|
) |
|
|
|
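
        # Flatten the chat-style messages into plain text, loading any referenced
        # images and prepending image placeholder tokens to the message text.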
logger.info("Processing the messages...") |
|
messages, images = [], [] |
|
for message in data["inputs"]: |
|
logger.info(f"Processing {message=}...") |
|
if isinstance(message["content"], list): |
|
new_message = {"role": message["role"], "content": ""} |
|
for content in message["content"]: |
|
logger.info(f"{content=} is of type {content['type']}") |
|
if content["type"] == "text": |
|
new_message["content"] += content["text"] |
|
elif content["type"] == "image_url": |
|
images.append(load_image(content["image_url"]["url"])) |
|
if new_message["content"].count( |
|
f"{IMAGE_TOKENS}{SEPARATOR}" |
|
) < len(images): |
|
new_message["content"] = ( |
|
f"{IMAGE_TOKENS}{SEPARATOR}" + new_message["content"] |
|
) |
|
messages.append(new_message) |
|
else: |
|
messages.append( |
|
{"role": message["role"], "content": message["content"]} |
|
) |
|

        data.pop("inputs")
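
        # Render the messages into a single prompt string via the tokenizer's chat template.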
logger.info(f"Applying chat template to {messages=}") |
|
prompt = self.processor.tokenizer.apply_chat_template( |
|
messages, tokenize=False, add_generation_prompt=True |
|
) |
|
|
|
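
        # Preprocess the prompt and images; the `unsqueeze(0)` calls add a batch
        # dimension to the image tensors (this step assumes the payload contains at
        # least one image), and the inputs are moved to the GPU in bfloat16 to match
        # the model.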
logger.info(f"Processing {len(images)} images...") |
|
inputs = self.processor(images=images, texts=prompt, return_tensors="pt") |
|
inputs["pixel_values"] = inputs["pixel_values"].unsqueeze(0) |
|
inputs["image_sizes"] = inputs["image_sizes"].unsqueeze(0) |
|
inputs = inputs.to("cuda").to(torch.bfloat16) |
|
|
|
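
        # Greedy decoding with a configurable token budget; `temperature` is accepted
        # for API compatibility but has no effect while `do_sample` is False.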
        generation_args = {
            "max_new_tokens": data.get("max_new_tokens", data.get("max_tokens", 128)),
            "temperature": data.get("temperature", 0.0),
            "do_sample": False,
            "use_cache": True,
            "num_beams": 1,
        }
        logger.info(
            f"Running text generation with the following {generation_args=} "
            f"(skipped {set(data.keys()) - set(generation_args.keys())})"
        )

        with torch.inference_mode():
            generate_ids = self.model.generate(**inputs, **generation_args)
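
        # Drop the prompt tokens so only the newly generated text is decoded.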
        generate_ids = generate_ids[:, inputs["input_ids"].shape[-1] :]
        response = self.processor.decode(
            generate_ids[0], skip_special_tokens=True
        ).strip()

        return {"generated_text": response}
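

# Minimal local smoke test (a sketch, not part of the handler contract): it assumes a
# CUDA GPU is available and that the example image URL below is reachable.
if __name__ == "__main__":
    handler = EndpointHandler()
    example_payload = {
        "inputs": [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "What is shown in this image?"},
                    # Hypothetical image URL used purely for illustration.
                    {"type": "image_url", "image_url": {"url": "https://example.com/image.jpg"}},
                ],
            }
        ],
        "max_new_tokens": 64,
    }
    print(handler(example_payload))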