""" | |
Copyright (c) 2025 Bytedance Ltd. and/or its affiliates | |
SPDX-License-Identifier: MIT | |
""" | |
import json | |
import os | |
from typing import Optional | |
import tensorrt_llm | |
import tensorrt_llm.profiler as profiler | |
import torch | |
from PIL import Image | |
from pydantic import BaseModel, Field | |
from tensorrt_llm import logger | |
from tensorrt_llm import mpi_rank | |
from tensorrt_llm.runtime import MultimodalModelRunner | |
from transformers import AutoTokenizer, DonutProcessor | |


class InferenceConfig(BaseModel):
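    """Inference settings for the Dolphin TensorRT-LLM runner.

    The fields mirror the engine/runtime options consumed through ``self.args``
    by TensorRT-LLM's ``MultimodalModelRunner`` (engine directories, sampling
    parameters, KV-cache sizing), so a validated instance can stand in for
    parsed command-line arguments.
    """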
    max_new_tokens: int = Field(128, description="Maximum new tokens to generate")
    batch_size: int = Field(1, description="Batch size for inference")
    log_level: str = Field("info", description="Logging level")
    visual_engine_dir: Optional[str] = Field(None, description="Directory for visual engine files")
    visual_engine_name: str = Field("model.engine", description="Visual engine filename")
    llm_engine_dir: Optional[str] = Field(None, description="Directory for LLM engine files")
    hf_model_dir: Optional[str] = Field(None, description="Hugging Face model directory")
    input_text: Optional[str] = Field(None, description="Input text for inference")
    num_beams: int = Field(1, description="Number of beams for beam search")
    top_k: int = Field(1, description="Top-k sampling value")
    top_p: float = Field(0.0, description="Top-p (nucleus) sampling value")
    temperature: float = Field(1.0, description="Sampling temperature")
    repetition_penalty: float = Field(1.0, description="Repetition penalty factor")
    run_profiling: bool = Field(False, description="Enable profiling mode")
    profiling_iterations: int = Field(20, description="Number of profiling iterations")
    check_accuracy: bool = Field(False, description="Enable accuracy checking")
    video_path: Optional[str] = Field(None, description="Path to input video file")
    video_num_frames: Optional[int] = Field(None, description="Number of video frames to process")
    image_path: Optional[str] = Field(None, description="Path to input image file")
    path_sep: str = Field(",", description="Separator between multiple input paths")
    prompt_sep: str = Field(",", description="Separator between multiple prompts")
    enable_context_fmha_fp32_acc: Optional[bool] = Field(
        None,
        description="Enable FP32 accumulation for context FMHA"
    )
    enable_chunked_context: bool = Field(False, description="Enable chunked context processing")
    use_py_session: bool = Field(False, description="Use Python session instead of C++")
    kv_cache_free_gpu_memory_fraction: float = Field(
        0.9,
        description="Fraction of free GPU memory reserved for the KV cache",
        ge=0.0, le=1.0
    )
    cross_kv_cache_fraction: float = Field(
        0.5,
        description="Fraction of the KV cache reserved for cross attention",
        ge=0.0, le=1.0
    )
    multi_block_mode: bool = Field(True, description="Enable multi-block processing mode")


class DolphinRunner(MultimodalModelRunner):
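    """TensorRT-LLM runner for the Dolphin document-parsing model.

    Re-implements initialization and generation for Dolphin's Donut-style
    encoder-decoder pipeline; the visual engine config is expected to report
    the ``nougat`` model type (see the asserts below), so the nougat code path
    of ``MultimodalModelRunner`` is reused.
    """
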
    def __init__(self, args):
        self.args = args

        self.runtime_rank = mpi_rank()
        device_id = self.runtime_rank % torch.cuda.device_count()
        torch.cuda.set_device(device_id)
        self.device = "cuda:%d" % (device_id)

        self.stream = torch.cuda.Stream(torch.cuda.current_device())
        torch.cuda.set_stream(self.stream)

        # parse model type from visual engine config
        with open(os.path.join(self.args.visual_engine_dir, "config.json"),
                  "r") as f:
            config = json.load(f)
        self.model_type = config['builder_config']['model_type']
        self.vision_precision = config['builder_config']['precision']

        self.decoder_llm = not (
            't5' in self.model_type
            or self.model_type in ['nougat', 'pix2struct']
        )  # BLIP2-T5, pix2struct and Nougat use encoder-decoder models as LLMs

        if self.model_type == "mllama":
            self.vision_input_names = [
                "pixel_values",
                "aspect_ratio_ids",
                "aspect_ratio_mask",
            ]
            self.vision_output_names = [
                "output",
            ]
        else:
            self.vision_input_names = ["input"]
            self.vision_output_names = ["output"]
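
        # Always use the Python session here; this intentionally overrides any
        # use_py_session value supplied in the config.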
        self.use_py_session = True

        self.init_image_encoder()
        self.init_tokenizer()
        self.init_processor()
        self.init_llm()

    def init_tokenizer(self):
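        # Dolphin's engines are built with the 'nougat' model type, so the
        # tokenizer/processor setup follows the nougat (Donut) path.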
        assert self.model_type == 'nougat'
        self.tokenizer = AutoTokenizer.from_pretrained(self.args.hf_model_dir)
        self.tokenizer.padding_side = "right"

    def init_processor(self):
        assert self.model_type == 'nougat'
        self.processor = DonutProcessor.from_pretrained(self.args.hf_model_dir, use_fast=True)

    def run(self, input_texts, input_images, max_new_tokens):
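        """Wrap each prompt in Dolphin's prompt template, preprocess the images
        with the Donut processor, and generate with the TRT-LLM engines.
        Returns one list of beam outputs (strings) per input."""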
        prompts = [f"<s>{text.strip()} <Answer/>" for text in input_texts]
        images = self.processor(input_images, return_tensors="pt")['pixel_values'].to("cuda")
        prompt_ids = self.tokenizer(prompts, add_special_tokens=False, return_tensors="pt").input_ids.to("cuda")

        # 🚨🚨🚨 Important! If the type of prompt_ids is not int32, the output will be wrong. 🚨🚨🚨
        prompt_ids = prompt_ids.to(torch.int32)

        logger.info("---------------------------------------------------------")
        logger.info(f"images size: {images.size()}")
        logger.info(f"prompt_ids: {prompt_ids}, size: {prompt_ids.size()}, dtype: {prompt_ids.dtype}")
        logger.info("---------------------------------------------------------")

        output_texts = self.generate(
            input_texts,
            [None] * len(input_texts),
            images,
            prompt_ids,
            max_new_tokens,
            warmup=False,
        )
        return output_texts

    def generate(self,
                 pre_prompt,
                 post_prompt,
                 image,
                 decoder_input_ids,
                 max_new_tokens,
                 warmup=False,
                 other_vision_inputs={},
                 other_decoder_inputs={}):
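        """Encoder-decoder generation used for Dolphin's nougat-style path:
        visual features are passed to the encoder via prompt tuning, while the
        decoder starts from the already tokenized ``decoder_input_ids``."""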
        if not warmup:
            profiler.start("Generate")

        input_ids, input_lengths, ptuning_args, visual_features = self.preprocess(
            warmup, pre_prompt, post_prompt, image, other_vision_inputs)

        if warmup:
            return None

        # use prompt tuning to pass multimodal features
        # model.generate() expects the following params (see layers/embedding.py):
        # args[0]: prompt embedding table, [batch_size, multimodal_len, hidden_size],
        #          later flattened to [batch_size * multimodal_len, hidden_size]
        # args[1]: prompt task ids, [batch_size]. In the multimodal case this is
        #          arange(batch_size), i.e. in VILA batching mode 2 each image is
        #          treated separately in the batch instead of being concatenated
        #          (although the prompt embedding table has to be concatenated)
        # args[2]: prompt task vocab size, [1]. Assumes all tables have the same
        #          length, which in the multimodal case equals multimodal_len
        profiler.start("LLM")
        if self.model_type in ['nougat', 'pix2struct']:
            # Trim encoder input_ids to match visual features shape
            ids_shape = (min(self.args.batch_size, len(pre_prompt)), visual_features.shape[1])
            if self.model_type == 'nougat':
                input_ids = torch.zeros(ids_shape, dtype=torch.int32)
            elif self.model_type == 'pix2struct':
                input_ids = torch.ones(ids_shape, dtype=torch.int32)

            output_ids = self.model.generate(
                input_ids,
                decoder_input_ids,
                max_new_tokens,
                num_beams=self.args.num_beams,
                bos_token_id=self.tokenizer.bos_token_id,
                pad_token_id=self.tokenizer.pad_token_id,
                eos_token_id=self.tokenizer.eos_token_id,
                debug_mode=False,
                prompt_embedding_table=ptuning_args[0],
                prompt_tasks=ptuning_args[1],
                prompt_vocab_size=ptuning_args[2],
            )
        profiler.stop("LLM")

        if mpi_rank() == 0:
            # Extract a list of tensors of shape beam_width x output_ids.
            output_beams_list = [
                self.tokenizer.batch_decode(
                    output_ids[batch_idx, :, decoder_input_ids.shape[1]:],
                    skip_special_tokens=False) for batch_idx in range(
                        min(self.args.batch_size, decoder_input_ids.shape[0]))
            ]

            stripped_text = [[
                output_beams_list[batch_idx][beam_idx].replace("</s>", "").replace("<pad>", "").strip()
                for beam_idx in range(self.args.num_beams)
            ] for batch_idx in range(
                min(self.args.batch_size, decoder_input_ids.shape[0]))]
            profiler.stop("Generate")
            return stripped_text
        else:
            profiler.stop("Generate")
            return None


if __name__ == "__main__":
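    # Example usage: assumes the Dolphin HF checkpoint and pre-built TensorRT
    # engines already exist under the ./tmp paths below, plus a sample page
    # image from the demo directory.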
    config = InferenceConfig(
        max_new_tokens=4024,
        batch_size=16,
        log_level="info",
        hf_model_dir="./tmp/hf_models/Dolphin",
        visual_engine_dir="./tmp/trt_engines/Dolphin/vision_encoder",
        llm_engine_dir="./tmp/trt_engines/Dolphin/1-gpu/bfloat16",
    )

    model = DolphinRunner(config)

    image_path = "../../demo/page_imgs/page_1.jpeg"
    prompt = "Parse the reading order of this document."
    image = Image.open(image_path).convert("RGB")
    output_texts = model.run([prompt], [image], 4024)
    output_texts = [texts[0] for texts in output_texts]
    print(output_texts)