from typing import Any, Dict

import torch
from huggingface_inference_toolkit.logging import logger
from transformers import AutoModelForCausalLM, AutoProcessor
from transformers.image_utils import load_image

# Placeholder tokens that mark where an image is injected into the prompt, and the
# separator expected between the placeholder and the message text.
IMAGE_TOKENS = "<image_start><image><image_end>"
SEPARATOR = "\n"

class EndpointHandler:
    def __init__(
        self,
        model_dir: str = "alvarobartt/Magma-8B",
        **kwargs: Any,  # type: ignore
    ) -> None:
        # Load the model in bfloat16 on the GPU; `trust_remote_code=True` is required
        # because Magma ships custom modeling and processing code.
        self.model = AutoModelForCausalLM.from_pretrained(
            model_dir, trust_remote_code=True, torch_dtype=torch.bfloat16
        ).eval()
        self.model.to("cuda")
        self.processor = AutoProcessor.from_pretrained(
            model_dir, trust_remote_code=True
        )

    def __call__(self, data: Dict[str, Any]) -> Any:
        logger.info(f"Received payload with {data}")

        if "inputs" not in data:
            raise ValueError(
                "The request body must contain a key 'inputs' with a list of messages."
            )

        logger.info("Processing the messages...")
        messages, images = [], []
        for message in data["inputs"]:
            logger.info(f"Processing {message=}...")
            if isinstance(message["content"], list):
                # Flatten the OpenAI-style content list into a single text string,
                # loading the referenced images along the way.
                new_message = {"role": message["role"], "content": ""}
                for content in message["content"]:
                    logger.info(f"{content=} is of type {content['type']}")
                    if content["type"] == "text":
                        new_message["content"] += content["text"]
                    elif content["type"] == "image_url":
                        images.append(load_image(content["image_url"]["url"]))
if new_message["content"].count(
f"{IMAGE_TOKENS}{SEPARATOR}"
) < len(images):
new_message["content"] = (
f"{IMAGE_TOKENS}{SEPARATOR}" + new_message["content"]
)
messages.append(new_message)
            else:
                messages.append(
                    {"role": message["role"], "content": message["content"]}
                )
        data.pop("inputs")

        logger.info(f"Applying chat template to {messages=}")
        prompt = self.processor.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )

        logger.info(f"Processing {len(images)} images...")
        inputs = self.processor(images=images, texts=prompt, return_tensors="pt")
        inputs["pixel_values"] = inputs["pixel_values"].unsqueeze(0)
        inputs["image_sizes"] = inputs["image_sizes"].unsqueeze(0)
        inputs = inputs.to("cuda").to(torch.bfloat16)

        # Any remaining payload keys other than the ones read below are ignored.
        generation_args = {
            "max_new_tokens": data.get("max_new_tokens", data.get("max_tokens", 128)),
            "temperature": data.get("temperature", 0.0),
            "do_sample": False,  # temperature won't really work unless this is set to True
            "use_cache": True,
            "num_beams": 1,
        }
        logger.info(
            f"Running text generation with the following {generation_args=}"
            f" (skipped {set(data.keys()) - set(generation_args.keys())})"
        )

        with torch.inference_mode():
            generate_ids = self.model.generate(**inputs, **generation_args)

        # Drop the prompt tokens so that only the newly generated text is decoded.
        generate_ids = generate_ids[:, inputs["input_ids"].shape[-1] :]
        response = self.processor.decode(
            generate_ids[0], skip_special_tokens=True
        ).strip()

        return {"generated_text": response}