"""Blockwise weight quantization for ONNX models.

Replaces float initializers feeding Conv/Gemm/MatMul nodes with integer
initializers plus DequantizeLinear (opset 21 ``block_size``/``axis``) nodes,
and a Reshape node when the original weight had more than two dimensions.
"""
import sys

MIN_PYTHON_VERSION = (3, 7)
if sys.version_info < MIN_PYTHON_VERSION:
    raise ImportError("This script requires Python 3.7 or higher!")

import argparse
import os
from dataclasses import dataclass, field
from typing import List, Optional, Tuple

import numpy as np
import onnx
from onnx import helper

BITS_TO_NUMPY_TYPE = {8: np.int8, 16: np.int16}
SUPPORTED_OPS = {"Conv", "Gemm", "MatMul"}
ONNX_OPSET = 21


@dataclass
class BlockQuantizeConfig:
    """User-supplied settings for one block-quantization run."""

    input_model_path: str  # float ONNX model to read (must end in .onnx)
    output_model_path: str  # where the quantized model is written (.onnx)
    block_size: int  # maximum number of elements sharing one scale/zero point
    bits: int  # integer width of quantized weights; a key of BITS_TO_NUMPY_TYPE


@dataclass
class BlockQuantizeResult:
    """Output of ``BlockQuantizer.block_quantize`` for a single weight tensor."""

    # Integer weights, flattened to at most 2-D by block_quantize.
    quantized_weights: np.ndarray = field(default_factory=lambda: np.array([]))
    # One scale per block along `axis`.
    scales: np.ndarray = field(default_factory=lambda: np.array([]))
    # One zero point per block, same shape as `scales`.
    zero_point: np.ndarray = field(default_factory=lambda: np.array([]))
    # Effective block size (a divisor of the quantized dimension).
    block_size: int = 1
    # Axis of `quantized_weights` along which the blocks are laid out.
    axis: int = 1
    # Shape of the weight before it was flattened.
    original_shape: Tuple = field(default_factory=tuple)
    # Frobenius norm of (dequantized - original); block_quantize stores the
    # scalar returned by np.linalg.norm here.
    quantization_error: np.ndarray = field(default_factory=lambda: np.array([]))


def closest_divisor(number: int, divisor: int) -> int:
    """Return the largest ``d`` in ``[1, divisor]`` that divides ``number``."""
    for d in range(divisor, 0, -1):
        if number % d == 0:
            return d
    return 1


def block_dequantize_tensor(
    x: np.ndarray, block_axis: int, scale: np.ndarray, zero_point: np.ndarray
) -> np.ndarray:
    """Dequantize ``x`` using one (scale, zero point) pair per block.

    ``scale``/``zero_point`` hold one entry per block along ``block_axis``.
    They are expanded element-wise by repeating each entry
    ``x.shape[block_axis] // scale.shape[block_axis]`` times.
    """
    repeats = x.shape[block_axis] // scale.shape[block_axis]
    x_scale_elementwise = np.repeat(scale, repeats=repeats, axis=block_axis)
    x_zero_point_elementwise = np.repeat(zero_point, repeats=repeats, axis=block_axis)
    return (
        x.astype(np.float32) - x_zero_point_elementwise.astype(np.float32)
    ) * x_scale_elementwise


def block_quantize_tensor(
    x: np.ndarray,
    block_axis: int,
    scale: np.ndarray,
    zero_point: np.ndarray,
    n_bits: int,
) -> np.ndarray:
    """Quantize ``x`` to the ``n_bits`` integer type, one scale/zero point per block.

    Results are saturated to the integer type's range before the cast.
    """
    qtype = BITS_TO_NUMPY_TYPE[n_bits]
    qinfo = np.iinfo(qtype)
    repeats = x.shape[block_axis] // scale.shape[block_axis]
    y_scale_elementwise = np.repeat(scale, repeats=repeats, axis=block_axis)
    y_zero_point_elementwise = np.repeat(zero_point, repeats=repeats, axis=block_axis)
    y = np.rint(x / y_scale_elementwise + y_zero_point_elementwise)
    # Saturate like ONNX QuantizeLinear. Zero-point rounding (or float error)
    # can push boundary values one step out of range, and a bare astype()
    # would silently wrap them to the opposite end of the integer range.
    return np.clip(y, qinfo.min, qinfo.max).astype(qtype)


def create_dequantize_node(
    node_name,
    quantized_weights,
    scales,
    zero_point,
    dequantized_weights,
    block_size,
    axis,
) -> onnx.NodeProto:
    """Build a blocked DequantizeLinear node (``block_size`` and ``axis`` attributes)."""
    block_size_attr = helper.make_attribute("block_size", block_size)
    axis_attr = helper.make_attribute("axis", axis)
    n = helper.make_node(
        "DequantizeLinear",
        inputs=[quantized_weights, scales, zero_point],
        outputs=[dequantized_weights],
        name=node_name,
    )
    n.attribute.extend([block_size_attr, axis_attr])
    return n


def create_reshape_node(
    node_name, dequantized_weights, shape_tensor, reshaped_weights_name
) -> onnx.NodeProto:
    """Build a Reshape node restoring a dequantized weight to its original shape."""
    return helper.make_node(
        "Reshape",
        inputs=[dequantized_weights, shape_tensor],
        outputs=[reshaped_weights_name],
        name=node_name,
    )


class BlockQuantizer:
    """Rewrites an ONNX model so supported ops read block-quantized weights."""

    def __init__(self, conf: BlockQuantizeConfig) -> None:
        self.conf = conf
        self.validate_conf()
        self.model = onnx.load(conf.input_model_path)
        # Blocked DequantizeLinear needs opset 21, so bring the default domain
        # to that version when it differs.
        if self._default_opset_version() != ONNX_OPSET:
            self.model = onnx.version_converter.convert_version(self.model, ONNX_OPSET)
        self.graph = self.model.graph
        self.initializers_map = {init.name: init for init in self.graph.initializer}

    def _default_opset_version(self) -> Optional[int]:
        """Return the opset version of the default ONNX domain, or None if absent."""
        # opset_import may also list other domains (e.g. ai.onnx.ml) in any order.
        for opset in self.model.opset_import:
            if opset.domain in ("", "ai.onnx"):
                return opset.version
        return None

    def validate_conf(self):
        """Validate the configuration.

        Raises:
            ValueError: if the paths, block size or bit width are invalid.
        """
        if not os.path.isfile(self.conf.input_model_path):
            raise ValueError(
                f"Input model path '{self.conf.input_model_path}' does not exist or is not a file."
            )
        if not self.conf.input_model_path.lower().endswith(".onnx"):
            raise ValueError(
                f"Input model path '{self.conf.input_model_path}' must have a .onnx extension."
            )
        if not self.conf.output_model_path.lower().endswith(".onnx"):
            raise ValueError(
                f"Output model path '{self.conf.output_model_path}' must have a .onnx extension."
            )
        if self.conf.block_size <= 0:
            raise ValueError("Block size must be a positive integer.")
        if self.conf.bits not in BITS_TO_NUMPY_TYPE:
            allowed_values = ", ".join([str(k) for k in BITS_TO_NUMPY_TYPE.keys()])
            raise ValueError(
                f"Bits must be one of the following values: [{allowed_values}]."
            )

    def get_initializer_tensor(self, name: str) -> Optional[np.ndarray]:
        """Return initializer ``name`` (from the load-time map) as numpy, else None."""
        if name in self.initializers_map:
            return onnx.numpy_helper.to_array(self.initializers_map[name])
        return None

    def compute_scale_zeropoint(
        self, b_min: np.ndarray, b_max: np.ndarray
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Compute asymmetric per-block scales and zero points.

        Args:
            b_min: per-block minimum values.
            b_max: per-block maximum values, same shape as ``b_min``.

        Returns:
            ``(scales, zeropoints)``, both shaped like ``b_min``. Scales keep
            the float dtype of the inputs. Zero points use the integer type
            selected by ``self.conf.bits``.

        Raises:
            ValueError: if a bound is not finite or some minimum exceeds its maximum.
        """
        if not (np.all(np.isfinite(b_min)) and np.all(np.isfinite(b_max))):
            raise ValueError("cannot quantize blocks containing NaN or infinite values")
        if np.any(b_min > b_max):
            raise ValueError(
                "minimum must not exceed maximum when computing scale and zero point"
            )
        # zero must be present in the range, this enforces qmin <= zero_point <= qmax
        b_min = np.minimum(b_min, np.zeros_like(b_min))
        b_max = np.maximum(b_max, np.zeros_like(b_max))
        qtype = BITS_TO_NUMPY_TYPE[self.conf.bits]
        qmin = np.iinfo(qtype).min
        qmax = np.iinfo(qtype).max
        scales = (b_max - b_min) / (qmax - qmin)
        # Constant blocks are fine once zero is included. Only an all-zero
        # block has an empty range, and any positive scale represents it
        # exactly. Use 1 to avoid dividing by zero below.
        scales = np.where(scales > 0, scales, np.ones_like(scales))
        zeropoints = np.rint(qmin - b_min / scales).astype(qtype)
        return scales, zeropoints

    def block_quantize(self, weight: np.ndarray) -> BlockQuantizeResult:
        """Block-quantize one weight tensor.

        How the tensor is blocked:
        - Tensors with ``ndim > 1`` are flattened to ``(shape[0], -1)`` and
          blocked along axis 1.
        - 1-D tensors are blocked along axis 0.
        - The effective block size is the largest divisor of that dimension
          not exceeding ``conf.block_size``.
        """
        original_shape = weight.shape
        if weight.ndim > 1:
            weight = weight.reshape((weight.shape[0], -1))
            quantization_axis = 1
        else:
            quantization_axis = 0
        # closest_divisor guarantees the dimension splits evenly into blocks.
        block_size = closest_divisor(
            weight.shape[quantization_axis], self.conf.block_size
        )
        # (..., n) -> (..., n // block_size, block_size) so the reductions over
        # the last axis give one min/max per block.
        blocked_shape = list(weight.shape[: quantization_axis + 1]) + [-1]
        blocked_shape[quantization_axis] //= block_size
        blocked_weight = weight.reshape(blocked_shape)
        scales, zeropoints = self.compute_scale_zeropoint(
            np.min(blocked_weight, -1), np.max(blocked_weight, -1)
        )
        quantized_weight = block_quantize_tensor(
            weight, quantization_axis, scales, zeropoints, self.conf.bits
        )
        reconstructed = block_dequantize_tensor(
            quantized_weight, quantization_axis, scales, zeropoints
        )
        return BlockQuantizeResult(
            quantized_weight,
            scales,
            zeropoints,
            block_size,
            quantization_axis,
            original_shape,
            np.linalg.norm(reconstructed - weight),
        )

    def get_model_size(self, model_path: str) -> float:
        """Return the size of the file at ``model_path`` in KiB (bytes / 1024)."""
        return os.path.getsize(model_path) / 1024

    def display_summary(self, sqe: List):
        """Print the quantization summary.

        ``sqe`` holds one squared error norm per quantized tensor. The
        "Mean Squared Quantization Error" line reports their mean, or "n/a"
        when no tensor was quantized.
        """
        if sqe:
            mse_text = f"{sum(sqe) / len(sqe):.6f}"
        else:
            mse_text = "n/a (no tensor quantized)"
        original_model_size = self.get_model_size(self.conf.input_model_path)
        quantized_model_size = self.get_model_size(self.conf.output_model_path)
        print("Done! Results saved in", self.conf.output_model_path)
        print("\nSummary of Results:\n")
        print(f"{'Metric':<30} {'Value':<10}")
        print(f"{'-'*40}")
        print(f"{'Mean Squared Quantization Error':<30} {mse_text}")
        print(f"{'Original Model Size (KB)':<31} {original_model_size:,.2f}")
        print(f"{'Block-Quantized Model Size (KB)':<30} {quantized_model_size:,.2f}")

    @staticmethod
    def _replacement_name(input_name: str, weight: np.ndarray) -> str:
        """Name of the tensor consumers read once ``input_name`` is quantized."""
        if weight.ndim > 2:
            return f"{input_name}_reshaped"
        return f"{input_name}_dequantized"

    def _quantize_initializer(
        self, input_name: str, weight: np.ndarray
    ) -> Tuple[int, float]:
        """Replace float initializer ``input_name`` with its quantized form.

        Graph changes:
        - Adds the quantized weight, scale and zero-point initializers
          (plus an int64 shape initializer when a reshape is needed).
        - Removes the float initializer.
        - Prepends a DequantizeLinear node (and a Reshape node when the
          weight had more than 2 dims) plus matching value_info entries.

        Returns:
            ``(number of nodes inserted at the front of the graph,
            squared quantization error norm)``.
        """
        quantized_weights_name = f"{input_name}_quantized"
        quantized_node_name = f"{input_name}_quantized_node"
        dequantized_weights_name = f"{input_name}_dequantized"
        scales_name = f"{input_name}_scales"
        zero_point_name = f"{input_name}_zero_point"
        shape_node_name = f"{input_name}_shape_node"
        shape_name = f"{input_name}_shape"
        reshaped_weights_name = f"{input_name}_reshaped"

        reshape_needed = weight.ndim > 2
        res = self.block_quantize(weight)

        new_initializers = [
            onnx.numpy_helper.from_array(res.scales, name=scales_name),
            onnx.numpy_helper.from_array(res.zero_point, name=zero_point_name),
            onnx.numpy_helper.from_array(
                res.quantized_weights, name=quantized_weights_name
            ),
        ]
        new_nodes = [
            create_dequantize_node(
                quantized_node_name,
                quantized_weights_name,
                scales_name,
                zero_point_name,
                dequantized_weights_name,
                res.block_size,
                res.axis,
            )
        ]
        new_value_infos = [
            helper.make_tensor_value_info(
                dequantized_weights_name,
                onnx.TensorProto.FLOAT,
                res.quantized_weights.shape,
            )
        ]
        # The shape tensor, Reshape node and its value_info exist only when
        # the flattened weight must be restored to its original rank.
        if reshape_needed:
            # Reshape's `shape` input must be int64 whatever numpy's default int is.
            new_initializers.append(
                onnx.numpy_helper.from_array(
                    np.array(res.original_shape, dtype=np.int64), name=shape_name
                )
            )
            new_nodes.append(
                create_reshape_node(
                    shape_node_name,
                    dequantized_weights_name,
                    shape_name,
                    reshaped_weights_name,
                )
            )
            new_value_infos.append(
                helper.make_tensor_value_info(
                    reshaped_weights_name,
                    onnx.TensorProto.FLOAT,
                    res.original_shape,
                )
            )

        self.graph.initializer.extend(new_initializers)
        # Removing fp32 weights
        self.graph.initializer.remove(
            next(init for init in self.graph.initializer if init.name == input_name)
        )
        # Prepend in reverse so the graph starts DequantizeLinear, Reshape, ...
        # which keeps the node list topologically ordered.
        for new_node in reversed(new_nodes):
            self.graph.node.insert(0, new_node)
        for value_info in reversed(new_value_infos):
            self.graph.value_info.insert(0, value_info)
        return len(new_nodes), float(res.quantization_error) ** 2

    def run(self):
        """Quantize the model, check and save it, then print a summary.

        Eligible inputs are initializer inputs of ``SUPPORTED_OPS`` nodes. An
        initializer shared by several inputs is quantized once, and every
        consumer is rewired to the same dequantized tensor.
        """
        print("Quantizing the model...")
        quantized_inputs = set()
        sqe = []
        node_idx = 0
        while node_idx < len(self.graph.node):
            node = self.graph.node[node_idx]
            if node.op_type in SUPPORTED_OPS:
                for input_idx, input_name in enumerate(node.input):
                    weight = self.get_initializer_tensor(input_name)
                    # Skip quantization in these cases:
                    # - the weight is an external input (not an initializer);
                    # - it is a scalar;
                    # - it has too few elements for at least one block.
                    if (
                        weight is None
                        or weight.ndim == 0
                        or weight.size < self.conf.block_size
                    ):
                        continue
                    # In case of parameter sharing the dequantization nodes
                    # already exist; only the input needs rewiring.
                    if input_name not in quantized_inputs:
                        quantized_inputs.add(input_name)
                        inserted, squared_error = self._quantize_initializer(
                            input_name, weight
                        )
                        # Nodes were inserted at the front, shifting this one.
                        node_idx += inserted
                        sqe.append(squared_error)
                    node.input[input_idx] = self._replacement_name(input_name, weight)
            node_idx += 1

        onnx.checker.check_model(self.model, full_check=True)
        onnx.save(self.model, self.conf.output_model_path)
        self.display_summary(sqe)


def setup_args() -> argparse.Namespace:
    """Parse command-line arguments for the quantization tool."""
    parser = argparse.ArgumentParser(description="Blockwise quantization tool")
    parser.add_argument(
        "-i",
        "--input_model",
        type=str,
        help="The path of onnx model to quantize",
        required=True,
    )
    parser.add_argument(
        "-bs",
        "--block_size",
        type=int,
        help="The maximum size of quantization block",
        required=True,
    )
    parser.add_argument(
        "-b",
        "--bits",
        type=int,
        help="Quantization bits",
        choices=sorted(BITS_TO_NUMPY_TYPE),
        default=8,
        required=False,
    )
    parser.add_argument(
        "-o",
        "--output_model",
        type=str,
        help="The output model path",
        default="block_quantized_model.onnx",
        required=False,
    )
    return parser.parse_args()


if __name__ == "__main__":
    args = setup_args()
    quantization_config = BlockQuantizeConfig(
        input_model_path=args.input_model,
        output_model_path=args.output_model,
        block_size=args.block_size,
        bits=args.bits,
    )
    quantizer = BlockQuantizer(quantization_config)
    quantizer.run()