DaniAffCH committed
Commit 0a88ce4 · 1 Parent(s): 6911e5d

[GSoC] Gemm and MatMul block quantization support (#268)


* Gemm and MatMul block quantization support

* refactoring

* fix indentation

* node name independent
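A rough NumPy sketch of the round trip this tool performs on every quantized weight, using an illustrative shape, block size, and the signed 8-bit type adopted below (none of these values come from the patch itself):

```python
import numpy as np

rng = np.random.default_rng(0)
w = rng.standard_normal((4, 64), dtype=np.float32)  # e.g. a Gemm/MatMul weight
blocks = w.reshape(4, 64 // 16, 16)                 # axis 1 split into blocks of 16

w_min = blocks.min(-1, keepdims=True)               # per-block range
w_max = blocks.max(-1, keepdims=True)
scale = np.maximum((w_max - w_min) / 255.0, 1e-12)  # int8 spans 256 levels
zero_point = np.round(-128.0 - w_min / scale)

q = np.clip(np.round(blocks / scale) + zero_point, -128, 127).astype(np.int8)
dq = (q.astype(np.float32) - zero_point) * scale    # what DequantizeLinear recovers

print("max abs error:", np.abs(dq - blocks).max())
```

Keeping one scale and zero point per block bounds the error locally: an outlier inflates the scale of the 16 values in its own block rather than of the whole row.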

Files changed (1)
  1. tools/quantize/block_quantize.py +134 -114
tools/quantize/block_quantize.py CHANGED
@@ -14,12 +14,10 @@ import numpy as np
 import onnx
 from onnx import helper
 
-BITS_TO_NUMPY_TYPE = {8: np.uint8, 16: np.uint16}
+BITS_TO_NUMPY_TYPE = {8: np.int8, 16: np.int16}
 
 
-SUPPORTED_OPS = {
-    "Conv"
-}
+SUPPORTED_OPS = {"Conv", "Gemm", "MatMul"}
 
 ONNX_OPSET = 21
 
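Switching BITS_TO_NUMPY_TYPE to signed types moves the 8-bit target range from [0, 255] to [-128, 127]. The patched compute_scale_zeropoint is outside this diff; the sketch below assumes a plain min/max affine mapping onto the signed range, purely to illustrate the arithmetic the type change implies:

```python
import numpy as np

def scale_zeropoint(bits: int, b_min: np.ndarray, b_max: np.ndarray):
    # Signed ranges matching BITS_TO_NUMPY_TYPE:
    # int8 -> [-128, 127], int16 -> [-32768, 32767]
    qmin, qmax = -(2 ** (bits - 1)), 2 ** (bits - 1) - 1
    scale = np.maximum((b_max - b_min) / (qmax - qmin), 1e-12)
    zero_point = np.round(qmin - b_min / scale)
    return scale, zero_point

s, zp = scale_zeropoint(8, np.array([-0.4]), np.array([0.9]))
print(s, zp)  # one block: scale ~0.0051, zero point -50
```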
 
@@ -43,12 +41,6 @@ class BlockQuantizeResult:
     quantization_error: np.ndarray = field(default_factory=lambda: np.array([]))
 
 
-@dataclass
-class LayerParams:
-    weights: np.ndarray = field(default_factory=lambda: np.array([]))
-    bias: Optional[np.ndarray] = None
-
-
 def closest_divisor(number: int, divisor: int) -> int:
     for d in range(divisor, 0, -1):
         if number % d == 0:
@@ -169,18 +161,6 @@ class BlockQuantizer:
 
         return None
 
-    def get_layer_params(self, node: onnx.NodeProto) -> LayerParams:
-        params = LayerParams()
-
-        weights_name = node.input[1]
-        params.weights = self.get_initializer_tensor(weights_name)
-
-        if len(node.input) > 2:
-            bias_name = node.input[2]
-            params.bias = self.get_initializer_tensor(bias_name)
-
-        return params
-
     def compute_scale_zeropoint(
         self, b_min: np.ndarray, b_max: np.ndarray
     ) -> Tuple[np.ndarray, np.ndarray]:
@@ -208,24 +188,28 @@ class BlockQuantizer:
 
     def block_quantize(self, weight: np.ndarray) -> BlockQuantizeResult:
         original_shape = weight.shape
-        weight = weight.reshape((weight.shape[0], -1))
 
-        quantization_axis = 1
+        if weight.ndim > 1:
+            weight = weight.reshape((weight.shape[0], -1))
+            quantization_axis = 1
+        else:
+            quantization_axis = 0
 
-        block_size = closest_divisor(weight.shape[1], self.conf.block_size)
+        block_size = closest_divisor(
+            weight.shape[quantization_axis], self.conf.block_size
+        )
 
         assert (
-            weight.shape[1] % block_size == 0
-        ), f"weight shape ({weight.shape[1]}) must be divisible by block size ({block_size})"
+            weight.shape[quantization_axis] % block_size == 0
+        ), f"weight shape ({weight.shape[quantization_axis]}) must be divisible by block size ({block_size})"
 
-        # Warning, axis = 1 specific instruction!
-        blocked_weight = weight.reshape(
-            (weight.shape[0], weight.shape[1] // block_size, -1)
-        )
+        # Flattening the tensor after the quantization axis
+        new_shape = list(weight.shape[: quantization_axis + 1]) + [-1]
+        new_shape[quantization_axis] = new_shape[quantization_axis] // block_size
+
+        blocked_weight = weight.reshape(new_shape)
 
-        # Warning, axis = 1 specific instruction!
         blocked_max = np.max(blocked_weight, -1)
-        # Warning, axis = 1 specific instruction!
         blocked_min = np.min(blocked_weight, -1)
 
         scales, zeropoints = self.compute_scale_zeropoint(blocked_min, blocked_max)
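The rewritten block_quantize generalizes the old 2-D-only path: 1-D operands (a bias, or a vector feeding MatMul) are blocked along axis 0, while higher-rank weights are first flattened to (weight.shape[0], -1) and blocked along axis 1. A standalone sketch of the shape arithmetic, with illustrative tensors:

```python
import numpy as np

def blocked_view(weight: np.ndarray, block_size: int):
    # Mirrors the new_shape computation above: flatten everything after
    # the quantization axis, then split that axis into blocks.
    if weight.ndim > 1:
        weight = weight.reshape((weight.shape[0], -1))
        axis = 1
    else:
        axis = 0
    new_shape = list(weight.shape[: axis + 1]) + [-1]
    new_shape[axis] //= block_size
    return weight.reshape(new_shape), axis

conv_w = np.zeros((8, 3, 3, 3), dtype=np.float32)
print(blocked_view(conv_w, 9)[0].shape)   # (8, 3, 9): rows of 27 -> 3 blocks of 9
vec = np.zeros(64, dtype=np.float32)
print(blocked_view(vec, 16)[0].shape)     # (4, 16): blocked along axis 0
```

The min/max reduction over the last dimension then yields one scale and zero point per block, whatever the original rank was.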
@@ -273,93 +257,129 @@ class BlockQuantizer:
     def run(self):
         print("Quantizing the model...")
 
-        visited_nodes = []
+        quantized_inputs = []
         sqe = []
 
-        for node in self.model.graph.node:
-            if node.name in visited_nodes:
-                continue
+        node_idx = 0
+
+        while node_idx < len(self.model.graph.node):
+            node = self.model.graph.node[node_idx]
+
             if node.op_type in SUPPORTED_OPS:
-                conv_params = self.get_layer_params(node)
-                block_quantize_res = self.block_quantize(conv_params.weights)
-
-                quantized_weights_name = f"{node.name}_quantized_weights"
-                quantized_node_name = f"{node.name}_quantized_node"
-                dequantized_weights_name = f"{node.name}_dequantized_weights"
-                scales_name = f"{node.name}_scales"
-                zero_point_name = f"{node.name}_zero_point"
-
-                shape_node_name = f"{node.name}_shape_node"
-                shape_name = f"{node.name}_shape"
-                reshaped_weights_name = f"{node.name}_reshaped_weights"
-
-                dequantize_node = create_dequantize_node(
-                    quantized_node_name,
-                    quantized_weights_name,
-                    scales_name,
-                    zero_point_name,
-                    dequantized_weights_name,
-                    block_quantize_res.block_size,
-                    block_quantize_res.axis,
-                )
-                reshape_node = create_reshape_node(
-                    shape_node_name,
-                    dequantized_weights_name,
-                    shape_name,
-                    reshaped_weights_name,
-                )
-
-                shape_tensor = onnx.numpy_helper.from_array(
-                    np.array(block_quantize_res.original_shape), name=shape_name
-                )
-                scale_initializer = onnx.numpy_helper.from_array(
-                    block_quantize_res.scales, name=scales_name
-                )
-                zero_point_initializer = onnx.numpy_helper.from_array(
-                    block_quantize_res.zero_point, name=zero_point_name
-                )
-                quantized_weights_initializer = onnx.numpy_helper.from_array(
-                    block_quantize_res.quantized_weights, name=quantized_weights_name
-                )
-
-                dequantized_weights_info = helper.make_tensor_value_info(
-                    dequantized_weights_name,
-                    onnx.TensorProto.FLOAT,
-                    block_quantize_res.quantized_weights.shape,
-                )
-                shape_info = helper.make_tensor_value_info(
-                    reshaped_weights_name,
-                    onnx.TensorProto.FLOAT,
-                    block_quantize_res.original_shape,
-                )
-
-                self.graph.initializer.extend(
-                    [
-                        scale_initializer,
-                        zero_point_initializer,
-                        shape_tensor,
-                        quantized_weights_initializer,
-                    ]
-                )
-
-                # Removing fp32 weights
-                self.graph.initializer.remove(
-                    next(
-                        init
-                        for init in self.graph.initializer
-                        if init.name == node.input[1]
+                for input_idx, input_name in enumerate(node.input):
+                    weight = self.get_initializer_tensor(input_name)
+
+                    quantized_weights_name = f"{input_name}_quantized"
+                    quantized_node_name = f"{input_name}_quantized_node"
+                    dequantized_weights_name = f"{input_name}_dequantized"
+                    scales_name = f"{input_name}_scales"
+                    zero_point_name = f"{input_name}_zero_point"
+
+                    shape_node_name = f"{input_name}_shape_node"
+                    shape_name = f"{input_name}_shape"
+                    reshaped_weights_name = f"{input_name}_reshaped"
+
+                    # Skip quantization if weights are taken as external input
+                    # or if they don't contain enough elements to create at least 1 block
+                    if weight is None or weight.size < self.conf.block_size:
+                        continue
+
+                    reshape_needed = weight.ndim > 2
+
+                    # In case of parameter sharing
+                    if input_name in quantized_inputs:
+                        node.input[input_idx] = (
+                            reshaped_weights_name
+                            if reshape_needed
+                            else dequantized_weights_name
+                        )
+                        continue
+
+                    quantized_inputs.append(input_name)
+                    block_quantize_res = self.block_quantize(weight)
+
+                    dequantize_node = create_dequantize_node(
+                        quantized_node_name,
+                        quantized_weights_name,
+                        scales_name,
+                        zero_point_name,
+                        dequantized_weights_name,
+                        block_quantize_res.block_size,
+                        block_quantize_res.axis,
                     )
-                )
-                node.input[1] = reshaped_weights_name
 
-                # Preserving the topological order of graph nodes
-                self.graph.node.insert(0, reshape_node)
-                self.graph.node.insert(0, dequantize_node)
-                self.graph.value_info.insert(0, shape_info)
-                self.graph.value_info.insert(0, dequantized_weights_info)
+                    if reshape_needed:
+                        reshape_node = create_reshape_node(
+                            shape_node_name,
+                            dequantized_weights_name,
+                            shape_name,
+                            reshaped_weights_name,
+                        )
+
+                    shape_tensor = onnx.numpy_helper.from_array(
+                        np.array(block_quantize_res.original_shape), name=shape_name
+                    )
+                    scale_initializer = onnx.numpy_helper.from_array(
+                        block_quantize_res.scales, name=scales_name
+                    )
+                    zero_point_initializer = onnx.numpy_helper.from_array(
+                        block_quantize_res.zero_point, name=zero_point_name
+                    )
+                    quantized_weights_initializer = onnx.numpy_helper.from_array(
+                        block_quantize_res.quantized_weights,
+                        name=quantized_weights_name,
+                    )
+
+                    dequantized_weights_info = helper.make_tensor_value_info(
+                        dequantized_weights_name,
+                        onnx.TensorProto.FLOAT,
+                        block_quantize_res.quantized_weights.shape,
+                    )
+
+                    if reshape_needed:
+                        shape_info = helper.make_tensor_value_info(
+                            reshaped_weights_name,
+                            onnx.TensorProto.FLOAT,
+                            block_quantize_res.original_shape,
+                        )
+
+                    self.graph.initializer.extend(
+                        [
+                            scale_initializer,
+                            zero_point_initializer,
+                            shape_tensor,
+                            quantized_weights_initializer,
+                        ]
+                    )
+
+                    # Removing fp32 weights
+                    self.graph.initializer.remove(
+                        next(
+                            init
+                            for init in self.graph.initializer
+                            if init.name == input_name
+                        )
+                    )
+
+                    node.input[input_idx] = (
+                        reshaped_weights_name
+                        if reshape_needed
+                        else dequantized_weights_name
+                    )
+
+                    # Preserving graph nodes topological order
+                    if reshape_needed:
+                        self.graph.node.insert(0, reshape_node)
+                        node_idx += 1
+
+                    self.graph.node.insert(0, dequantize_node)
+                    node_idx += 1
+                    self.graph.value_info.insert(0, shape_info)
+                    self.graph.value_info.insert(0, dequantized_weights_info)
 
-                sqe.append(block_quantize_res.quantization_error**2)
-                visited_nodes.append(node.name)
+                    sqe.append(block_quantize_res.quantization_error**2)
+
+            node_idx += 1
 
         onnx.checker.check_model(self.model, full_check=True)
         onnx.save(self.model, self.conf.output_model_path)
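create_dequantize_node and create_reshape_node sit outside this diff. For reference, a hedged sketch of what such a DequantizeLinear factory can look like under ONNX_OPSET = 21, where DequantizeLinear gained the block_size attribute this script relies on (the helper name and argument order mirror the call sites above, but the implementation is assumed):

```python
import onnx
from onnx import helper

def create_dequantize_node(node_name, quantized_name, scales_name,
                           zero_point_name, output_name, block_size, axis):
    # Opset 21 blocked dequantization: the scales and zero-point inputs
    # hold one entry per block along `axis`; block_size tells the runtime
    # how many consecutive elements share each entry.
    return helper.make_node(
        "DequantizeLinear",
        inputs=[quantized_name, scales_name, zero_point_name],
        outputs=[output_name],
        name=node_name,
        axis=axis,
        block_size=block_size,
    )
```

Note how run() inserts each new node at position 0 of graph.node and bumps node_idx: prepending keeps the producer topologically ahead of every consumer, and the index bump stops the while loop from revisiting nodes it has just inserted.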