"""Custom Inference Endpoints handler for multimodal (image + text) generation with Magma-8B."""

from typing import Any, Dict

import torch
from huggingface_inference_toolkit.logging import logger
from transformers import AutoModelForCausalLM, AutoProcessor
from transformers.image_utils import load_image

IMAGE_TOKENS = "<image_start><image><image_end>"
SEPARATOR = "\n"


class EndpointHandler:
    def __init__(
        self,
        model_dir: str = "alvarobartt/Magma-8B",
        **kwargs: Any,  # type: ignore
    ) -> None:
        # Load the model in bfloat16 and move it to the GPU; `trust_remote_code=True`
        # is required because Magma ships its own modeling and processing code.
        self.model = AutoModelForCausalLM.from_pretrained(
            model_dir, trust_remote_code=True, torch_dtype=torch.bfloat16
        ).eval()
        self.model.to("cuda")
        self.processor = AutoProcessor.from_pretrained(
            model_dir, trust_remote_code=True
        )

    def __call__(self, data: Dict[str, Any]) -> Any:
        logger.info(f"Received payload with {data}")
        if "inputs" not in data:
            raise ValueError(
                "The request body must contain a key 'inputs' with a list of messages."
            )

        logger.info("Processing the messages...")
        messages, images = [], []
        for message in data["inputs"]:
            logger.info(f"Processing {message=}...")
            if isinstance(message["content"], list):
                new_message = {"role": message["role"], "content": ""}
                for content in message["content"]:
                    logger.info(f"{content=} is of type {content['type']}")
                    if content["type"] == "text":
                        new_message["content"] += content["text"]
                    elif content["type"] == "image_url":
                        images.append(load_image(content["image_url"]["url"]))
                        if new_message["content"].count(
                            f"{IMAGE_TOKENS}{SEPARATOR}"
                        ) < len(images):
                            new_message["content"] = (
                                f"{IMAGE_TOKENS}{SEPARATOR}" + new_message["content"]
                            )
                messages.append(new_message)
            else:
                messages.append(
                    {"role": message["role"], "content": message["content"]}
                )

        data.pop("inputs")

        logger.info(f"Applying chat template to {messages=}")
        prompt = self.processor.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )

        logger.info(f"Processing {len(images)} images...")
        inputs = self.processor(images=images, texts=prompt, return_tensors="pt")
        # Add the batch dimension the model expects, then move everything to the GPU in bfloat16.
        inputs["pixel_values"] = inputs["pixel_values"].unsqueeze(0)
        inputs["image_sizes"] = inputs["image_sizes"].unsqueeze(0)
        inputs = inputs.to("cuda").to(torch.bfloat16)

        generation_args = {
            "max_new_tokens": data.get("max_new_tokens", data.get("max_tokens", 128)),
            "temperature": data.get("temperature", 0.0),
            "do_sample": False,  # temperature won't really work unless this is set to True
            "use_cache": True,
            "num_beams": 1,
        }
        logger.info(
            f"Running text generation with {generation_args=} "
            f"(ignoring payload keys {set(data.keys()) - set(generation_args.keys())})"
        )

        with torch.inference_mode():
            generate_ids = self.model.generate(**inputs, **generation_args)

        # Strip the prompt tokens so only the newly generated text is decoded.
        generate_ids = generate_ids[:, inputs["input_ids"].shape[-1] :]
        response = self.processor.decode(
            generate_ids[0], skip_special_tokens=True
        ).strip()

        return {"generated_text": response}