
Commit

Address PR comments
PatriceVignola committed Sep 19, 2024
1 parent 0bd92b7 commit 461c6eb
Showing 2 changed files with 39 additions and 25 deletions.
12 changes: 12 additions & 0 deletions src/python/py/models/README.md
@@ -185,6 +185,18 @@ python3 -m onnxruntime_genai.models.builder -i path_to_local_folder_on_disk -o p
python3 builder.py -i path_to_local_folder_on_disk -o path_to_output_folder -p precision -e execution_provider -c cache_dir_to_store_temp_files --extra_options use_8bits_moe=1
```

#### Use QDQ Pattern for Quantization

This scenario is for when you want to quantize the model to 4 bits using the QDQ pattern (DequantizeLinear + MatMul) instead of the MatMulNBits operator.

```
# From wheel:
python3 -m onnxruntime_genai.models.builder -i path_to_local_folder_on_disk -o path_to_output_folder -p precision -e execution_provider -c cache_dir_to_store_temp_files --extra_options use_qdq=1
# From source:
python3 builder.py -i path_to_local_folder_on_disk -o path_to_output_folder -p precision -e execution_provider -c cache_dir_to_store_temp_files --extra_options use_qdq=1
```
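
For reference, the resulting graph pattern for one quantized linear layer looks roughly like the sketch below. This is a hand-written illustration, not code from this repository; the tensor names and group size are made up, and block-wise DequantizeLinear requires opset 21 (onnx >= 1.16).

```
# Illustrative only: the QDQ pattern emitted per linear layer when use_qdq=1.
from onnx import helper

dequantize = helper.make_node(
    "DequantizeLinear",
    inputs=["model.layers.0.mlp.down_proj.qweight",   # packed INT4 weight (initializer)
            "model.layers.0.mlp.down_proj.scales"],   # per-group scales (initializer)
    outputs=["/model/layers.0/mlp/down_proj/DequantizeLinear/output_0"],
    name="/model/layers.0/mlp/down_proj/DequantizeLinear",
    block_size=32,                                     # quantization group size (illustrative)
)
matmul = helper.make_node(
    "MatMul",
    inputs=["hidden_states",
            "/model/layers.0/mlp/down_proj/DequantizeLinear/output_0"],
    outputs=["/model/layers.0/mlp/down_proj/MatMul/output_0"],
    name="/model/layers.0/mlp/down_proj/MatMul",
)
```

In the builder itself (see the `builder.py` changes below), a Transpose is also inserted between the two nodes so that the dequantized weight reaches the MatMul in the expected orientation.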

### Unit Testing Models

This scenario is where your PyTorch model is already downloaded locally (either in the default Hugging Face cache directory or in a local folder on disk). If it is not already downloaded locally, here is an example of how you can download it.
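
For example, a minimal sketch assuming the `huggingface_hub` package is installed (the model id below is only illustrative):

```
# Illustrative sketch: download a Hugging Face model into a local folder
# so it can be passed to the builder via -i / --input.
from huggingface_hub import snapshot_download

snapshot_download(
    repo_id="microsoft/Phi-3-mini-4k-instruct",   # illustrative model id
    local_dir="path_to_local_folder_on_disk",     # matches the -i argument above
)
```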
52 changes: 27 additions & 25 deletions src/python/py/models/builder.py
```diff
@@ -721,55 +721,57 @@ def make_matmul_int4(self, matmul, basename, root_input, **kwargs):
         self.make_value_info(output, self.io_dtype, shape=['batch_size', 'sequence_length', matmul.out_features])
 
         return name
 
-    def make_matmul_int4_qdq(self, matmul, matmul_name, root_input, **kwargs):
-        if not hasattr(matmul, "qweight"):
-            # TODO: quantize weights, then save new MatMul numpy weights for onnx model
-            # print(f"Quantizing to {self.onnx_dtype} on-the-fly is not currently supported.")
-            # print(f"Saving as {self.io_dtype} on-the-fly and quantizing to {self.onnx_dtype} at the end.")
-            return self.make_matmul_fp16_or_fp32(matmul, matmul_name, root_input, **kwargs)
-
-        dequantize_name = f"{matmul_name}/DequantizeLinear"
-
+    def make_dequantize_linear(self, dequantize_name, quantized_op):
         # Input weights are quantized, save quantized MatMul numpy weights for onnx model
         qweight = dequantize_name[1:].replace("/", ".") + ".qweight"
-        qweight_npy = matmul.qweight.detach().numpy()
-        qweight_npy = qweight_npy.reshape(qweight_npy.shape[0], qweight_npy.shape[1] * qweight_npy.shape[2])
+        qweight_npy = quantized_op.qweight.detach().numpy()
+        qweight_npy = qweight_npy.reshape(*qweight_npy.shape[:-2], qweight_npy.shape[-2] * qweight_npy.shape[-1])
         self.make_external_tensor(qweight_npy, qweight, True)
 
         scales = dequantize_name[1:].replace("/", ".") + ".scales"
-        scales_npy = matmul.scales.detach().numpy().astype(self.to_numpy_dtype[self.io_dtype])
-        scales_npy = scales_npy.reshape(*qweight_npy.shape[:-1], qweight_npy.shape[-1] * 2 // matmul.group_size)
+        scales_npy = quantized_op.scales.detach().numpy().astype(self.to_numpy_dtype[self.io_dtype])
+        scales_npy = scales_npy.reshape(*qweight_npy.shape[:-1], qweight_npy.shape[-1] * 2 // quantized_op.group_size)
         self.make_external_tensor(scales_npy, scales)
 
         dequantize_inputs = [qweight, scales]
 
-        if hasattr(matmul, "qzeros") and matmul.qzeros is not None:
+        if hasattr(quantized_op, "qzeros") and quantized_op.qzeros is not None:
             zeros = dequantize_name[1:].replace("/", ".") + ".qzeros"
-            zeros_npy = matmul.qzeros.detach().numpy()
-            zeros_npy = zeros_npy.reshape(*qweight_npy.shape[:-1], qweight_npy.shape[-1] // matmul.group_size)
+            zeros_npy = quantized_op.qzeros.detach().numpy()
+            zeros_npy = zeros_npy.reshape(*qweight_npy.shape[:-1], qweight_npy.shape[-1] // quantized_op.group_size)
             self.make_external_tensor(zeros_npy, zeros, True)
             dequantize_inputs.append(zeros)
 
         dequantize_output = f"{dequantize_name}/output_0"
-        self.make_node("DequantizeLinear", inputs=dequantize_inputs, outputs=[dequantize_output], name=dequantize_name, block_size=matmul.group_size)
-        self.make_value_info(dequantize_output, self.io_dtype, shape=[*scales_npy.shape[:-1], scales_npy.shape[-1] * matmul.group_size])
+        self.make_node("DequantizeLinear", inputs=dequantize_inputs, outputs=[dequantize_output], name=dequantize_name, block_size=quantized_op.group_size)
+        self.make_value_info(dequantize_output, self.io_dtype, shape=[*scales_npy.shape[:-1], scales_npy.shape[-1] * quantized_op.group_size])
 
+        return dequantize_output
+
+    def make_matmul_int4_qdq(self, matmul, matmul_name, root_input, **kwargs):
+        if not hasattr(matmul, "qweight"):
+            # TODO: quantize weights, then save new MatMul numpy weights for onnx model
+            # print(f"Quantizing to {self.onnx_dtype} on-the-fly is not currently supported.")
+            # print(f"Saving as {self.io_dtype} on-the-fly and quantizing to {self.onnx_dtype} at the end.")
+            return self.make_matmul_fp16_or_fp32(matmul, matmul_name, root_input, **kwargs)
+
+        dequantize_output = self.make_dequantize_linear(f"{matmul_name}/DequantizeLinear", matmul)
+
         # Add a transpose instead of transposing the weights offline. The reason for this is that it is more natural and usually more performant to
         # compute quantized matmul when the weights are transposed. In most implementations, the transpose should usually be converted to a "transposeB"
         # attribute on the MatMul itself. A more natural way to represent this would have been to use Gemm since it already supports a transB attribute,
         # but unfortunately Gemm doesn't support batches.
-        perms = list(range(0, len(scales_npy.shape)))
+        qweight_shape = matmul.scales.detach().numpy().shape
+        qweight_shape = [*qweight_shape[:-2], qweight_shape[-2] * qweight_shape[-1] * 2]
+        perms = list(range(0, len(qweight_shape)))
         perms[-2], perms[-1] = perms[-1], perms[-2]
-        transposed_shape = [*scales_npy.shape[:-2], scales_npy.shape[-1] * matmul.group_size, scales_npy.shape[-2]]
-
+        transposed_shape = [*qweight_shape[:-2], qweight_shape[-1] * matmul.group_size, qweight_shape[-2]]
         transpose_name = f"{matmul_name}/Transpose"
-        transpose_output = f"{transpose_name}/output_0"
-        self.make_node("Transpose", inputs=[dequantize_output], outputs=[transpose_output], name=transpose_name, perm=perms)
-        self.make_value_info(transpose_output, self.io_dtype, shape=transposed_shape)
+        self.make_transpose(transpose_name, dequantize_output, self.io_dtype, transposed_shape, perms)
 
         matmul_output = "logits" if kwargs.get("logits", False) else f"{matmul_name}/output_0"
-        self.make_node("MatMul", inputs=[root_input, transpose_output], outputs=[matmul_output], name=matmul_name)
+        self.make_node("MatMul", inputs=[root_input, f"{transpose_name}/output_0"], outputs=[matmul_output], name=matmul_name)
         self.make_value_info(matmul_output, self.io_dtype, shape=['batch_size', 'sequence_length', matmul.out_features])
 
         return matmul_name
```
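As a reading aid for the reshape logic in `make_dequantize_linear` above (a hand-written sketch, not part of the commit): each element of `qweight` packs two 4-bit values, so the dequantized weight has `qweight.shape[-1] * 2` columns, and `scales` holds one value per `group_size` of those columns. With illustrative sizes:

```
# Illustrative shape check for the packed-INT4 bookkeeping in make_dequantize_linear.
import numpy as np

out_features, in_features, group_size = 32, 64, 16              # made-up sizes
qweight = np.zeros((out_features, in_features // 2), np.uint8)  # two 4-bit values per element
scales = np.zeros((out_features, in_features // group_size), np.float16)

# Mirrors: scales_npy.reshape(*qweight_npy.shape[:-1], qweight_npy.shape[-1] * 2 // group_size)
assert scales.shape == (*qweight.shape[:-1], qweight.shape[-1] * 2 // group_size)

# Mirrors the DequantizeLinear output shape: [*scales.shape[:-1], scales.shape[-1] * group_size]
assert (*scales.shape[:-1], scales.shape[-1] * group_size) == (out_features, in_features)
```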
