module owlite.nn.functions.fake_quantize
function fake_quantize
```python
fake_quantize(
    inputs: Tensor,
    step_size: Tensor,
    zero_point: Tensor,
    quant_min: int,
    quant_max: int,
    axis: int | None = None
) -> Tensor
```
Apply the fake quantization function to the input with the given quantization parameters. Equivalent to `torch.fake_quantize_per_channel_affine` if `per_channel` is `True`, and to `torch.fake_quantize_per_tensor_affine` otherwise. In OwLite, quantization is simulated through the following mathematical expression:
$$ \small
\text{FakeQuantize}(\text{input}) =
\left(
\text{clip} \left(
{\biggl\lfloor \frac{\text{input}}{\text{step\_size}} \biggr\rceil} + \text{zero\_point},\
\text{quant\_min},\
\text{quant\_max}
\right) - \text{zero\_point}
\right) \times \text{step\_size}
$$
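As a concrete illustration, the expression above can be sketched in plain Python for a single scalar value (`fake_quantize_ref` is a hypothetical helper written for this document, not part of OwLite; the real function operates on tensors):

```python
def fake_quantize_ref(x: float, step_size: float, zero_point: int,
                      quant_min: int, quant_max: int) -> float:
    # Python's round() uses round-half-to-even, matching the rounding
    # operator in the formula above
    q = round(x / step_size) + zero_point
    # clip to the representable quantized domain
    q = max(quant_min, min(quant_max, q))
    # map the integer back to the original scale
    return (q - zero_point) * step_size

# INT8 symmetric quantization (zero_point fixed at 0) with step_size 0.1:
fake_quantize_ref(0.37, 0.1, 0, -128, 127)   # snaps to the grid point 0.4
fake_quantize_ref(100.0, 0.1, 0, -128, 127)  # saturates at 127 * 0.1
```

Note that the output stays in floating point: fake quantization only snaps values onto the quantization grid, which is what lets the quantization error flow through training while the model remains an ordinary float model.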
The primary objective of exporting to the Open Neural Network Exchange (ONNX) format is to facilitate deployment on TensorRT rather than the ONNX Runtime. Consequently, the export process is confined to transforming the model into a format compatible with TensorRT, specifically one that supports fake quantization. The incorporation of fake quantization involves the decomposition of the model into `QuantizeLinear` and `DequantizeLinear` operations within the ONNX specification. Subsequently, TensorRT is entrusted with the task of ingesting the resultant ONNX graph and executing it in INT8 format, optimizing the process to the fullest extent of its capabilities. For more information, see the TensorRT Developer Guide's section on Explicit Quantization.
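The decomposition can be sketched with two hypothetical scalar helpers mirroring the semantics of the ONNX `QuantizeLinear` and `DequantizeLinear` operators (illustrative only; the actual operators work on tensors and store the intermediate value as INT8):

```python
def quantize_linear(x: float, scale: float, zero_point: int,
                    qmin: int = -128, qmax: int = 127) -> int:
    # QuantizeLinear: scale, shift, round, and saturate to the INT8 range
    q = round(x / scale) + zero_point
    return max(qmin, min(qmax, q))

def dequantize_linear(q: int, scale: float, zero_point: int) -> float:
    # DequantizeLinear: undo the shift and scale back to floating point
    return (q - zero_point) * scale

# Chaining Q -> DQ reproduces the fake-quantized value from the formula above
dequantize_linear(quantize_linear(0.37, 0.1, 0), 0.1, 0)
```

Splitting fake quantization into this explicit Q/DQ pair is what allows TensorRT to recognize the quantization scales in the graph and fuse the surrounding operations into INT8 kernels.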
Args:
- `inputs` (`torch.Tensor`): A tensor to quantize.
- `step_size` (`torch.Tensor`): The quantization scale, determining the magnitude of each quantization interval.
- `zero_point` (`torch.Tensor`): The quantization zero_point. It may be expressed as a float in the context of asymmetric quantization, while for symmetric quantization it is fixed at 0.
- `quant_min` (`int`): The lower bound of the quantized domain, specified as an integer.
- `quant_max` (`int`): The upper bound of the quantized domain, specified as an integer.
- `axis` (`int`, optional): Channel axis. Only used when `per_channel` is `True`. Defaults to 0.
Returns:
- `torch.Tensor`: The fake-quantized tensor.
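To illustrate the role of the `axis` argument, here is a hypothetical pure-Python sketch of per-channel quantization over a nested list, where each row (a channel along axis 0) gets its own step size (`fake_quantize_per_channel_ref` is not part of OwLite):

```python
def fake_quantize_per_channel_ref(x, step_sizes, zero_point,
                                  quant_min, quant_max):
    # axis=0: one step_size per row (channel); zero_point shared for brevity
    out = []
    for row, s in zip(x, step_sizes):
        q_row = []
        for v in row:
            q = max(quant_min, min(quant_max, round(v / s) + zero_point))
            q_row.append((q - zero_point) * s)
        out.append(q_row)
    return out

weights = [[0.12, -0.34], [1.2, -3.4]]
# channel 0 uses a finer step (0.01) than channel 1 (0.1)
fake_quantize_per_channel_ref(weights, [0.01, 0.1], 0, -128, 127)
```

Per-channel quantization is typically applied to weights, where per-channel value ranges can differ by orders of magnitude and a single per-tensor scale would waste most of the INT8 range on the smaller channels.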
function int8_qdq_symbolic
```python
int8_qdq_symbolic(
    g: GraphContext,
    inputs: Value,
    step_size: Value,
    zero_point: Value,
    quant_min: int,
    quant_max: int,
    axis: int | None
) -> Value | tuple[Value, ...]
```
Define the symbolic computation graph for fake INT8 quantization.
Args:
- `g` (`jit_utils.GraphContext`): The graph context.
- `inputs` (`torch.Value`): A tensor to quantize.
- `step_size` (`torch.Value`): The quantization scale, determining the magnitude of each quantization interval.
- `zero_point` (`torch.Value`): The quantization zero_point. It may be expressed as a float in the context of asymmetric quantization, while for symmetric quantization it is fixed at 0.
- `quant_min` (`int`): The lower bound of the quantized domain, specified as an integer.
- `quant_max` (`int`): The upper bound of the quantized domain, specified as an integer.
- `axis` (`int`, optional): Channel axis. Only used when `per_channel` is `True`. Defaults to 0.
Returns:
The output value.
class BaseFakeINTQuantizeFunction
An autograd function for fake INT quantization.
Static Methods:
- `symbolic`: Defines the symbolic computation graph for the function.
method symbolic
```python
symbolic(
    g: GraphContext,
    inputs: Value,
    step_size: Value,
    zero_point: Value,
    grad_scale: float,
    quant_min: int,
    quant_max: int,
    axis: int | None
) -> Value | tuple[Value, ...]
```
Define the symbolic computation graph for the function.
Args:
- `g` (`jit_utils.GraphContext`): The graph context.
- `inputs` (`torch.Value`): A tensor to quantize.
- `step_size` (`torch.Value`): The quantization scale, determining the magnitude of each quantization interval.
- `zero_point` (`torch.Value`): The quantization zero_point. It may be expressed as a float in the context of asymmetric quantization, while for symmetric quantization it is fixed at 0.
- `grad_scale` (`float`): The gradient scale.
- `quant_min` (`int`): The lower bound of the quantized domain, specified as an integer.
- `quant_max` (`int`): The upper bound of the quantized domain, specified as an integer.
- `axis` (`int`, optional): Channel axis. Only used when `per_channel` is `True`. Defaults to 0.
Returns:
The output value.
Updated: 2024-06-13T23:42:41