# copied from: https://github.com/NVIDIA/TensorRT-LLM/blob/v0.18.1/examples/enc_dec/helper.py
import typing
from typing import Union

import numpy as np
import torch  # pytype: disable=import-error

from tensorrt_llm._utils import str_dtype_to_torch


def split(v: Union[np.ndarray, torch.Tensor],
          tp_size: int,
          tp_rank: int,
          dim=0):
    """Return the tp_rank-th shard of v along dim; pass through if tp_size == 1."""
    if tp_size == 1:
        if isinstance(v, np.ndarray):
            return np.ascontiguousarray(v.copy())
        else:
            return v.clone().detach()
    assert len(v.shape) > 1 or dim == 0
    if isinstance(v, np.ndarray):
        return np.ascontiguousarray(
            np.split(v, tp_size, axis=dim)[tp_rank].copy())
    else:
        assert v.shape[dim] % tp_size == 0, \
            f'Unable to split: shape={v.shape} (dim={dim}) tp_size={tp_size}.'
        split_size = v.shape[dim] // tp_size
        return v.split(split_size, dim=dim)[tp_rank].clone().detach()
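

# A usage sketch (illustrative shapes, not part of the original file): with
# tp_size=2, each rank keeps one contiguous half of the tensor along `dim`.
#
#   w = torch.arange(32, dtype=torch.float32).reshape(4, 8)
#   split(w, tp_size=2, tp_rank=0, dim=0)  # rows 0..1 -> shape (2, 8)
#   split(w, tp_size=2, tp_rank=1, dim=0)  # rows 2..3 -> shape (2, 8)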


def reshape(v: torch.Tensor, shape=None):
    if shape is None:
        return v.contiguous()
    else:
        return v.reshape(shape).contiguous()


def fuse_qkv_one_layer(params, attn_module_name, trtllm_layer_name, tp_size,
                       tp_rank, model_type, weight_shape, bias_shape):
    qkv_module_names = get_qkv_module_name(model_type)

    weight = {}

    # fuse weights of q, k, v
    q_w = params[f'{attn_module_name}.{qkv_module_names["q"]}.weight']
    k_w = params[f'{attn_module_name}.{qkv_module_names["k"]}.weight']
    v_w = params[f'{attn_module_name}.{qkv_module_names["v"]}.weight']
    shape = q_w.shape  # (do, din)
    qkv_w = torch.cat([q_w, k_w, v_w],
                      dim=0).reshape([3, shape[0], shape[1]])  # (3, do, din)
    # shard the output dimension so each rank keeps matching q/k/v slices
    qkv_w = split(qkv_w, tp_size, tp_rank, dim=1)
    weight[f'{trtllm_layer_name}.qkv.weight'] = reshape(qkv_w,
                                                        shape=weight_shape)

    # fuse qkv biases if present
    bias_name = f'{attn_module_name}.{qkv_module_names["q"]}.bias'
    if bias_name in params and params[bias_name] is not None:
        q_b = params[f'{attn_module_name}.{qkv_module_names["q"]}.bias']
        k_b = params[f'{attn_module_name}.{qkv_module_names["k"]}.bias']
        v_b = params[f'{attn_module_name}.{qkv_module_names["v"]}.bias']
        shape = q_b.shape[0]  # (do,)
        qkv_b = torch.cat([q_b, k_b, v_b], dim=0).reshape([3, shape])  # (3, do)
        qkv_b = split(qkv_b, tp_size, tp_rank, dim=1)
        weight[f'{trtllm_layer_name}.qkv.bias'] = reshape(qkv_b,
                                                          shape=bias_shape)
    return weight
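

# Note on the layout above (explanatory, not from the original file): the
# concatenated (3*do, din) matrix is first reshaped to (3, do, din) so the
# tensor-parallel split along dim=1 shards the output dimension of Q, K and
# V together; each rank then holds the matching Q/K/V rows for its heads.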


def get_qkv_module_name(model_type):
    if model_type in ["t5", "blip2"]:
        q = "q"
        k = "k"
        v = "v"
    elif model_type == "bart" or model_type == "nmt":
        q = "q_proj"
        k = "k_proj"
        v = "v_proj"
    elif model_type == "pix2struct":
        q = "query"
        k = "key"
        v = "value"
    else:
        raise ValueError(f"Unsupported model_type: {model_type}")
    return {"q": q, "k": k, "v": v}


def convert_weight_to_dtype(params: typing.Dict[str, torch.Tensor],
                            dtype: typing.Optional[str] = None):
    if dtype is not None:
        assert isinstance(
            dtype, str), f"dtype must be str, but got type {type(dtype)}"
        for name in params.keys():
            params[name] = params[name].to(str_dtype_to_torch(dtype))
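

if __name__ == "__main__":
    # Minimal smoke test (not part of the upstream helper; assumes torch and
    # tensorrt_llm are installed, as imported at the top of this file). It
    # fabricates one t5-style self-attention layer and fuses its Q/K/V
    # projections for rank 0 of a 2-way tensor-parallel split. All module
    # and layer names below are illustrative, not real checkpoint keys.
    d_model, tp_size, tp_rank = 8, 2, 0
    attn = "encoder.block.0.layer.0.SelfAttention"
    params = {
        f"{attn}.{name}.weight": torch.randn(d_model, d_model)
        for name in ("q", "k", "v")
    }
    convert_weight_to_dtype(params, "float16")
    fused = fuse_qkv_one_layer(
        params,
        attn_module_name=attn,
        trtllm_layer_name="encoder_layers.0.attention",
        tp_size=tp_size,
        tp_rank=tp_rank,
        model_type="t5",
        weight_shape=(3 * d_model // tp_size, d_model),
        bias_shape=None,  # unused here: the fabricated layer has no biases
    )
    for name, w in fused.items():
        # expected: encoder_layers.0.attention.qkv.weight (12, 8) torch.float16
        print(name, tuple(w.shape), w.dtype)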