Magma-8B / handler.py
from typing import Any, Dict

import torch
from huggingface_inference_toolkit.logging import logger
from transformers import AutoModelForCausalLM, AutoProcessor
from transformers.image_utils import load_image

# Magma expects every image to be referenced in the prompt through explicit
# placeholder tokens, one block per image
IMAGE_TOKENS = "<image_start><image><image_end>"
SEPARATOR = "\n"


class EndpointHandler:
    def __init__(
        self,
        model_dir: str = "alvarobartt/Magma-8B",
        **kwargs: Any,  # type: ignore
    ) -> None:
        # `trust_remote_code=True` is required because Magma ships its own
        # modeling and processing code inside the model repository
        self.model = AutoModelForCausalLM.from_pretrained(
            model_dir, trust_remote_code=True, torch_dtype=torch.bfloat16
        ).eval()
        self.model.to("cuda")
        self.processor = AutoProcessor.from_pretrained(
            model_dir, trust_remote_code=True
        )

    def __call__(self, data: Dict[str, Any]) -> Any:
        logger.info(f"Received payload: {data}")
        if "inputs" not in data:
            raise ValueError(
                "The request body must contain a key 'inputs' with a list of messages."
            )
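
        # The loop below accepts OpenAI-style chat messages, where `content`
        # is either a plain string or a list of typed parts. An illustrative
        # payload (the URL and text are hypothetical, not from a real request):
        # {
        #     "inputs": [
        #         {"role": "user", "content": [
        #             {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}},
        #             {"type": "text", "text": "Describe this image."},
        #         ]}
        #     ]
        # }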
logger.info("Processing the messages...")
messages, images = [], []
for message in data["inputs"]:
logger.info(f"Processing {message=}...")
if isinstance(message["content"], list):
new_message = {"role": message["role"], "content": ""}
for content in message["content"]:
logger.info(f"{content=} is of type {content['type']}")
if content["type"] == "text":
new_message["content"] += content["text"]
elif content["type"] == "image_url":
images.append(load_image(content["image_url"]["url"]))
if new_message["content"].count(
f"{IMAGE_TOKENS}{SEPARATOR}"
) < len(images):
new_message["content"] = (
f"{IMAGE_TOKENS}{SEPARATOR}" + new_message["content"]
)
messages.append(new_message)
else:
messages.append(
{"role": message["role"], "content": message["content"]}
)
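
        # For instance, a single user message carrying one image and the text
        # "Describe this image." leaves this loop as:
        #   {"role": "user",
        #    "content": "<image_start><image><image_end>\nDescribe this image."}
        # with the downloaded image appended to `images`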
        data.pop("inputs")

        logger.info(f"Applying chat template to {messages=}")
        prompt = self.processor.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
logger.info(f"Processing {len(images)} images...")
inputs = self.processor(images=images, texts=prompt, return_tensors="pt")
inputs["pixel_values"] = inputs["pixel_values"].unsqueeze(0)
inputs["image_sizes"] = inputs["image_sizes"].unsqueeze(0)
inputs = inputs.to("cuda").to(torch.bfloat16)

        generation_args = {
            "max_new_tokens": data.get("max_new_tokens", data.get("max_tokens", 128)),
            "temperature": data.get("temperature", 0.0),
            # temperature has no effect unless `do_sample` is set to True
            "do_sample": False,
            "use_cache": True,
            "num_beams": 1,
        }
        logger.info(
            f"Running text generation with the following {generation_args=}"
            f" (skipped {set(data.keys()) - set(generation_args.keys())})"
        )
        with torch.inference_mode():
            generate_ids = self.model.generate(**inputs, **generation_args)

        # Drop the prompt tokens so only the newly generated ones are decoded
        generate_ids = generate_ids[:, inputs["input_ids"].shape[-1] :]
        response = self.processor.decode(
            generate_ids[0], skip_special_tokens=True
        ).strip()
        return {"generated_text": response}
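

# A minimal local smoke test: a sketch, not part of the deployed handler. It
# assumes a CUDA GPU, access to the model weights, and a reachable image URL;
# on an Inference Endpoint the toolkit instantiates and calls the handler.
if __name__ == "__main__":
    handler = EndpointHandler()
    payload = {
        "inputs": [
            {
                "role": "user",
                "content": [
                    {
                        "type": "image_url",
                        # Hypothetical URL, replace with a real image
                        "image_url": {"url": "https://example.com/image.png"},
                    },
                    {"type": "text", "text": "What is shown in this image?"},
                ],
            }
        ],
        "max_new_tokens": 64,
    }
    print(handler(payload))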