opencv_zoo / tools /quantize /block_quantize.py
DaniAffCH's picture
[GSoC] Gemm and MatMul block quantization support (#268)
0a88ce4
raw
history blame
14.6 kB
import sys
MIN_PYTHON_VERSION = (3, 7)
# Fail fast with a readable message before importing modules such as
# `dataclasses`, which older interpreters do not provide.
if sys.version_info[:2] < MIN_PYTHON_VERSION:
    raise ImportError("This script requires Python 3.7 or higher!")
import argparse
import os
from dataclasses import dataclass, field
from typing import List, Optional, Tuple
import numpy as np
import onnx
from onnx import helper
# Quantization bit width -> signed NumPy integer dtype used to store weights.
BITS_TO_NUMPY_TYPE = {
    8: np.int8,
    16: np.int16,
}
# Op types whose initializer inputs are block-quantized.
SUPPORTED_OPS = {"MatMul", "Gemm", "Conv"}
# Target opset; DequantizeLinear's `block_size` attribute requires opset >= 21.
ONNX_OPSET = 21
@dataclass
class BlockQuantizeConfig:
    """Settings for one block-quantization run.

    Attributes:
        input_model_path: Path of the ``.onnx`` model to quantize.
        output_model_path: Path the quantized ``.onnx`` model is written to.
        block_size: Upper bound on the number of elements per quantization block.
        bits: Bit width of the stored integer weights (a key of
            ``BITS_TO_NUMPY_TYPE``).
    """

    input_model_path: str
    output_model_path: str
    block_size: int
    bits: int
@dataclass
class BlockQuantizeResult:
    """Container for the outcome of block-quantizing one weight tensor.

    Attributes:
        quantized_weights: Integer weights (as produced by
            ``BlockQuantizer.block_quantize`` they are flattened to at most 2-D).
        scales: One scale per block.
        zero_point: One integer zero point per block.
        block_size: Number of consecutive elements sharing a scale/zero point.
        axis: Axis of ``quantized_weights`` along which blocks are laid out.
        original_shape: Shape of the weight before flattening.
        quantization_error: Reconstruction error; ``BlockQuantizer`` stores the
            L2 norm of (dequantized - original) here, i.e. a scalar.
    """

    # Every array default is a fresh empty float64 array per instance.
    quantized_weights: np.ndarray = field(default_factory=lambda: np.empty(0))
    scales: np.ndarray = field(default_factory=lambda: np.empty(0))
    zero_point: np.ndarray = field(default_factory=lambda: np.empty(0))
    block_size: int = 1
    axis: int = 1
    original_shape: Tuple = field(default_factory=tuple)
    quantization_error: np.ndarray = field(default_factory=lambda: np.empty(0))
def closest_divisor(number: int, divisor: int) -> int:
    """Return the largest ``d`` with ``1 <= d <= divisor`` that divides ``number``.

    Falls back to 1 when ``divisor`` is not positive.
    """
    candidates = (d for d in range(divisor, 0, -1) if number % d == 0)
    return next(candidates, 1)
def block_dequantize_tensor(
    x: np.ndarray, block_axis: int, scale: np.ndarray, zero_point: np.ndarray
) -> np.ndarray:
    """Dequantize ``x`` block-wise: ``(x - zero_point) * scale``.

    ``scale`` and ``zero_point`` carry one entry per block along ``block_axis``.
    Each entry is repeated ``x.shape[block_axis] // scale.shape[block_axis]``
    times so it lines up element-wise with ``x``. The subtraction is done in
    float32.
    """
    block_len = x.shape[block_axis] // scale.shape[block_axis]

    def _expand(per_block: np.ndarray) -> np.ndarray:
        return np.repeat(per_block, repeats=block_len, axis=block_axis)

    centered = x.astype(np.float32) - _expand(zero_point).astype(np.float32)
    return centered * _expand(scale)
def block_quantize_tensor(
    x: np.ndarray,
    block_axis: int,
    scale: np.ndarray,
    zero_point: np.ndarray,
    n_bits: int,
) -> np.ndarray:
    """Quantize ``x`` block-wise to signed ``n_bits`` integers.

    ``scale`` and ``zero_point`` carry one entry per block along ``block_axis``.
    Each entry is repeated ``x.shape[block_axis] // scale.shape[block_axis]``
    times to line up element-wise with ``x``.

    Computes ``clip(rint(x / scale + zero_point), qmin, qmax)``, where
    ``[qmin, qmax]`` is the range of ``BITS_TO_NUMPY_TYPE[n_bits]``. Results
    saturate to that range, matching ONNX QuantizeLinear.

    Raises:
        KeyError: if ``n_bits`` is not a key of ``BITS_TO_NUMPY_TYPE``.
    """
    target_dtype = BITS_TO_NUMPY_TYPE[n_bits]
    limits = np.iinfo(target_dtype)
    repeats = x.shape[block_axis] // scale.shape[block_axis]
    y_scale_elementwise = np.repeat(scale, repeats=repeats, axis=block_axis)
    y_zero_point_elementwise = np.repeat(zero_point, repeats=repeats, axis=block_axis)
    y = np.rint(x / y_scale_elementwise + y_zero_point_elementwise)
    # The zero point is itself rounded, so a block's min/max can land up to half
    # a step outside [qmin, qmax] and round past it. Without clipping, the
    # float -> integer cast below would overflow instead of saturating.
    return np.clip(y, limits.min, limits.max).astype(target_dtype)
def create_dequantize_node(
    node_name,
    quantized_weights,
    scales,
    zero_point,
    dequantized_weights,
    block_size,
    axis,
) -> onnx.NodeProto:
    """Build a blocked ``DequantizeLinear`` node.

    Args:
        node_name: Name of the new node.
        quantized_weights: Name of the integer weight tensor (input 0).
        scales: Name of the per-block scale tensor (input 1).
        zero_point: Name of the per-block zero-point tensor (input 2).
        dequantized_weights: Name of the produced tensor.
        block_size: Value of the ``block_size`` attribute.
        axis: Value of the ``axis`` attribute.
    """
    node = helper.make_node(
        "DequantizeLinear",
        inputs=[quantized_weights, scales, zero_point],
        outputs=[dequantized_weights],
        name=node_name,
    )
    # Attributes are appended explicitly, `block_size` first, then `axis`.
    node.attribute.extend(
        helper.make_attribute(attr_name, attr_value)
        for attr_name, attr_value in (("block_size", block_size), ("axis", axis))
    )
    return node
def create_reshape_node(
    node_name, dequantized_weights, shape_tensor, reshaped_weights_name
) -> onnx.NodeProto:
    """Build a ``Reshape`` node.

    The node reshapes tensor ``dequantized_weights`` using the shape stored in
    tensor ``shape_tensor``. All arguments are tensor/node names.
    """
    node_inputs = [dequantized_weights, shape_tensor]
    node_outputs = [reshaped_weights_name]
    return helper.make_node("Reshape", node_inputs, node_outputs, node_name)
class BlockQuantizer:
    """Block-wise weight quantizer for ONNX models.

    Eligible initializers are rewritten; an initializer is eligible when it:
      * is an input of a node whose op type is in ``SUPPORTED_OPS``,
      * is not a scalar,
      * has at least ``conf.block_size`` elements.

    Each one is replaced by an integer initializer feeding a
    ``DequantizeLinear`` node. When the weight has more than two dimensions, a
    ``Reshape`` node restores its original shape.
    """

    def __init__(self, conf: BlockQuantizeConfig) -> None:
        """Validate ``conf``, load the input model and convert it to ``ONNX_OPSET``.

        Raises:
            ValueError: if ``conf`` is invalid (see ``validate_conf``).
        """
        self.conf = conf
        self.validate_conf()
        self.model = onnx.load(conf.input_model_path)
        if self._default_opset_version() != ONNX_OPSET:
            self.model = onnx.version_converter.convert_version(self.model, ONNX_OPSET)
        self.graph = self.model.graph
        # Snapshot of the original initializers. It is not updated when fp32
        # initializers are removed, so shared weights are still recognized.
        self.initializers_map = {
            init.name: init for init in self.model.graph.initializer
        }

    def _default_opset_version(self) -> Optional[int]:
        """Return the opset version of the default ONNX domain, or None if absent."""
        # opset_import[0] is not guaranteed to be the default domain: custom
        # domains (e.g. "com.microsoft") may be listed first.
        for opset in self.model.opset_import:
            if opset.domain in ("", "ai.onnx"):
                return opset.version
        return None

    def validate_conf(self):
        """Check paths, block size and bit width of ``self.conf``.

        Raises:
            ValueError: on a missing input file, a non-``.onnx`` path, a
                non-positive block size or an unsupported bit width.
        """
        if not os.path.isfile(self.conf.input_model_path):
            raise ValueError(
                f"Input model path '{self.conf.input_model_path}' does not exist or is not a file."
            )
        if not self.conf.input_model_path.lower().endswith(".onnx"):
            raise ValueError(
                f"Input model path '{self.conf.input_model_path}' must have a .onnx extension."
            )
        if not self.conf.output_model_path.lower().endswith(".onnx"):
            raise ValueError(
                f"Output model path '{self.conf.output_model_path}' must have a .onnx extension."
            )
        if self.conf.block_size <= 0:
            raise ValueError("Block size must be a positive integer.")
        if self.conf.bits not in BITS_TO_NUMPY_TYPE:
            allowed_values = ", ".join([str(k) for k in BITS_TO_NUMPY_TYPE.keys()])
            raise ValueError(
                f"Bits must be one of the following values: [{allowed_values}]."
            )

    def get_initializer_tensor(self, name: str) -> Optional[np.ndarray]:
        """Return the original initializer ``name`` as a NumPy array, or None."""
        if name in self.initializers_map:
            return onnx.numpy_helper.to_array(self.initializers_map[name])
        return None

    def compute_scale_zeropoint(
        self, b_min: np.ndarray, b_max: np.ndarray
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Compute per-block asymmetric scales and integer zero points.

        Each range is first widened to include zero. This makes 0.0 exactly
        representable and keeps qmin <= zero_point <= qmax. A block whose
        widened range is empty (all elements are zero) gets scale 1, which
        still represents it exactly.

        Args:
            b_min: Per-block minimum values.
            b_max: Per-block maximum values.

        Returns:
            ``(scales, zeropoints)``. The scales keep the dtype of the ranges;
            the zero points use the configured integer dtype.

        Raises:
            ValueError: if any widened range has ``b_min > b_max`` or is NaN.
        """
        # zero must be present in the range, this enforces qmin <= zero_point <= qmax
        b_min = np.minimum(b_min, np.zeros_like(b_min, dtype=b_min.dtype))
        b_max = np.maximum(b_max, np.zeros_like(b_max, dtype=b_max.dtype))
        # Checked after widening: a constant block (b_min == b_max), e.g. every
        # block when the effective block size is 1, is perfectly valid.
        if not (b_min <= b_max).all():
            raise ValueError(
                "minimum must not exceed maximum when computing scale and zero point"
            )
        qtype = BITS_TO_NUMPY_TYPE[self.conf.bits]
        qmin = np.iinfo(qtype).min
        qmax = np.iinfo(qtype).max
        value_range = b_max - b_min
        scales = value_range / (qmax - qmin)
        # Avoid a zero scale (division by zero below) for all-zero blocks.
        scales = np.where(value_range > 0, scales, np.ones_like(scales))
        zeropoints = np.rint(qmin - b_min / scales).astype(qtype)
        return (scales, zeropoints)

    def block_quantize(self, weight: np.ndarray) -> BlockQuantizeResult:
        """Block-quantize ``weight`` and measure the reconstruction error.

        How blocks are laid out depends on the weight's rank:
          * ndim > 1: the weight is flattened to ``(shape[0], -1)`` and
            blocked along axis 1.
          * 1-D: the weight is blocked along axis 0.

        The block size used is the largest divisor of that axis length that
        does not exceed ``conf.block_size``.
        """
        original_shape = weight.shape
        if weight.ndim > 1:
            weight = weight.reshape((weight.shape[0], -1))
            quantization_axis = 1
        else:
            quantization_axis = 0
        block_size = closest_divisor(
            weight.shape[quantization_axis], self.conf.block_size
        )
        assert (
            weight.shape[quantization_axis] % block_size == 0
        ), f"weight shape ({weight.shape[quantization_axis]}) must be divisible by block size ({block_size})"
        # Flattening the tensor after the quantization axis
        new_shape = list(weight.shape[: quantization_axis + 1]) + [-1]
        new_shape[quantization_axis] = new_shape[quantization_axis] // block_size
        blocked_weight = weight.reshape(new_shape)
        blocked_max = np.max(blocked_weight, -1)
        blocked_min = np.min(blocked_weight, -1)
        scales, zeropoints = self.compute_scale_zeropoint(blocked_min, blocked_max)
        quantized_weight = block_quantize_tensor(
            weight, quantization_axis, scales, zeropoints, self.conf.bits
        )
        reconstructed_mat = block_dequantize_tensor(
            quantized_weight, quantization_axis, scales, zeropoints
        )
        qerror = np.linalg.norm(reconstructed_mat - weight)
        return BlockQuantizeResult(
            quantized_weight,
            scales,
            zeropoints,
            block_size,
            quantization_axis,
            original_shape,
            qerror,
        )

    def get_model_size(self, model_path: str) -> float:
        """Return the size of the file at ``model_path`` in kilobytes (1024 bytes)."""
        size_bytes = os.path.getsize(model_path)
        return size_bytes / 1024

    def display_summary(self, sqe: List):
        """Print the quantization error and the input/output model sizes.

        Args:
            sqe: Squared L2 reconstruction error of each quantized tensor. May
                be empty when no weight was eligible for quantization.
        """
        # Averaged per tensor (not per element). Guard the empty case to avoid
        # a ZeroDivisionError when nothing was quantized.
        mse = sum(sqe) / len(sqe) if sqe else 0.0
        original_model_size = self.get_model_size(self.conf.input_model_path)
        quantized_model_size = self.get_model_size(self.conf.output_model_path)
        print("Done! Results saved in", self.conf.output_model_path)
        print("\nSummary of Results:\n")
        print(f"{'Metric':<30} {'Value':<10}")
        print(f"{'-'*40}")
        print(f"{'Mean Squared Quantization Error':<30} {mse:.6f}")
        print(f"{'Original Model Size (KB)':<31} {original_model_size:,.2f}")
        print(f"{'Block-Quantized Model Size (KB)':<30} {quantized_model_size:,.2f}")

    def _quantize_initializer(
        self, input_name: str, weight: np.ndarray
    ) -> Tuple[int, float]:
        """Replace the fp32 initializer ``input_name`` by its block-quantized form.

        Adds the following as initializers:
          * the integer weights,
          * the per-block scales and zero points,
          * the original shape (only for weights with more than two dimensions).

        The original initializer is then removed. A ``DequantizeLinear`` node,
        followed when needed by a ``Reshape`` node, is prepended to the graph.

        Returns:
            ``(number of nodes inserted at the front of the graph,
            L2 norm of the quantization error)``.
        """
        quantized_weights_name = f"{input_name}_quantized"
        quantized_node_name = f"{input_name}_quantized_node"
        dequantized_weights_name = f"{input_name}_dequantized"
        scales_name = f"{input_name}_scales"
        zero_point_name = f"{input_name}_zero_point"
        shape_node_name = f"{input_name}_shape_node"
        shape_name = f"{input_name}_shape"
        reshaped_weights_name = f"{input_name}_reshaped"
        reshape_needed = weight.ndim > 2

        res = self.block_quantize(weight)

        new_initializers = [
            onnx.numpy_helper.from_array(res.scales, name=scales_name),
            onnx.numpy_helper.from_array(res.zero_point, name=zero_point_name),
            onnx.numpy_helper.from_array(
                res.quantized_weights, name=quantized_weights_name
            ),
        ]
        # Listed in execution order: DequantizeLinear feeds the optional Reshape.
        new_nodes = [
            create_dequantize_node(
                quantized_node_name,
                quantized_weights_name,
                scales_name,
                zero_point_name,
                dequantized_weights_name,
                res.block_size,
                res.axis,
            )
        ]
        new_value_infos = [
            helper.make_tensor_value_info(
                dequantized_weights_name,
                onnx.TensorProto.FLOAT,
                res.quantized_weights.shape,
            )
        ]
        # Shape tensor, Reshape node and its value_info exist only when needed;
        # 2-D (Gemm/MatMul) weights are consumed directly from DequantizeLinear.
        if reshape_needed:
            # Reshape requires an int64 shape; NumPy's default integer dtype is
            # platform dependent (int32 on Windows before NumPy 2).
            new_initializers.append(
                onnx.numpy_helper.from_array(
                    np.array(res.original_shape, dtype=np.int64), name=shape_name
                )
            )
            new_nodes.append(
                create_reshape_node(
                    shape_node_name,
                    dequantized_weights_name,
                    shape_name,
                    reshaped_weights_name,
                )
            )
            new_value_infos.append(
                helper.make_tensor_value_info(
                    reshaped_weights_name,
                    onnx.TensorProto.FLOAT,
                    res.original_shape,
                )
            )

        self.graph.initializer.extend(new_initializers)
        # Removing fp32 weights
        self.graph.initializer.remove(
            next(init for init in self.graph.initializer if init.name == input_name)
        )
        # Preserving graph nodes topological order: the new nodes consume only
        # initializers (or each other), so they go to the front of the graph.
        for new_node in reversed(new_nodes):
            self.graph.node.insert(0, new_node)
        for info in reversed(new_value_infos):
            self.graph.value_info.insert(0, info)
        return len(new_nodes), res.quantization_error

    def run(self):
        """Quantize all eligible weights, then check, save and summarize the model."""
        print("Quantizing the model...")
        quantized_inputs = set()
        sqe = []
        node_idx = 0
        while node_idx < len(self.graph.node):
            node = self.graph.node[node_idx]
            if node.op_type in SUPPORTED_OPS:
                for input_idx, input_name in enumerate(node.input):
                    weight = self.get_initializer_tensor(input_name)
                    # Skip quantization if weights are taken as external input,
                    # are scalars, or don't contain enough elements to create
                    # at least 1 block
                    if (
                        weight is None
                        or weight.ndim == 0
                        or weight.size < self.conf.block_size
                    ):
                        continue
                    reshape_needed = weight.ndim > 2
                    node.input[input_idx] = (
                        f"{input_name}_reshaped"
                        if reshape_needed
                        else f"{input_name}_dequantized"
                    )
                    # In case of parameter sharing the dequantized tensor exists already
                    if input_name in quantized_inputs:
                        continue
                    quantized_inputs.add(input_name)
                    n_inserted, qerror = self._quantize_initializer(input_name, weight)
                    # Nodes were prepended, so the current node moved forward
                    node_idx += n_inserted
                    sqe.append(qerror**2)
            node_idx += 1
        onnx.checker.check_model(self.model, full_check=True)
        onnx.save(self.model, self.conf.output_model_path)
        self.display_summary(sqe)
def setup_args() -> argparse.Namespace:
    """Parse the command-line options of the blockwise quantization tool.

    Required: ``-i/--input_model`` and ``-bs/--block_size``.
    Optional: ``-b/--bits`` (8 or 16, default 8) and ``-o/--output_model``
    (default ``block_quantized_model.onnx``).
    """
    parser = argparse.ArgumentParser(description="Blockwise quantization tool")
    # Mandatory options
    parser.add_argument(
        "-i", "--input_model", type=str, required=True,
        help="The path of onnx model to quantize",
    )
    parser.add_argument(
        "-bs", "--block_size", type=int, required=True,
        help="The maximum size of quantization block",
    )
    # Optional options
    parser.add_argument(
        "-b", "--bits", type=int, choices=[8, 16], default=8,
        help="Quantization bits",
    )
    parser.add_argument(
        "-o", "--output_model", type=str, default="block_quantized_model.onnx",
        help="The output model path",
    )
    return parser.parse_args()
if __name__ == "__main__":
    # Parse CLI options, build the config and run the quantizer in one go.
    cli_args = setup_args()
    BlockQuantizer(
        BlockQuantizeConfig(
            input_model_path=cli_args.input_model,
            output_model_path=cli_args.output_model,
            block_size=cli_args.block_size,
            bits=cli_args.bits,
        )
    ).run()