module owlite.nn.functions.fake_fp_quantize
This module provides a custom PyTorch function for fake FP8 quantization.
The BaseFakeFPQuantizeFunction
class is a PyTorch autograd function that performs fake FP8 quantization. It takes an input tensor, step size, zero point, quantization minimum, quantization maximum, and axis as inputs. Its symbolic
method defines the symbolic computation graph for the function and checks that the quantization minimum and maximum are valid for FP8 quantization. If they are, it calls the fp8_qdq_symbolic
function to perform the quantization and dequantization.
The fake_fp8_quantize
function performs the actual fake FP8 quantization. It takes an input tensor, step size, zero point, quantization minimum, quantization maximum, and axis as inputs. It first reshapes the step size and zero point according to the axis, then quantizes by dividing the input by the step size and adding the zero point. The result is clipped to the quantization minimum and maximum and cast to the FP8 data type. Finally, the result is cast back to the original data type, the zero point is subtracted, and the difference is multiplied by the step size.
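The steps above can be sketched in plain Python for a single scalar value. This is a minimal sketch under assumptions: the E4M3 FP8 format with saturating round-to-nearest-even and a maximum normal value of 448, and `round_to_fp8_e4m3` is a hypothetical helper written for illustration, not part of OwLite.

```python
import math

def round_to_fp8_e4m3(x: float) -> float:
    # Hypothetical helper: round to the nearest FP8 E4M3 value
    # (3 mantissa bits, exponent bias 7, max normal 448, subnormals below 2**-6).
    if x == 0.0 or math.isnan(x):
        return x
    sign = -1.0 if x < 0.0 else 1.0
    a = min(abs(x), 448.0)                 # saturating cast to the FP8 range
    e = max(math.floor(math.log2(a)), -6)  # exponent, clamped into the subnormal range
    step = 2.0 ** (e - 3)                  # spacing between neighboring FP8 values
    return sign * min(round(a / step) * step, 448.0)  # ties-to-even via round()

def fake_fp8_quantize_scalar(x: float, step_size: float, zero_point: float,
                             quant_min: float = -448.0,
                             quant_max: float = 448.0) -> float:
    # 1. scale and shift, 2. clip, 3. cast to FP8, 4. cast back and undo scale/shift
    q = min(max(x / step_size + zero_point, quant_min), quant_max)
    q = round_to_fp8_e4m3(q)
    return (q - zero_point) * step_size
```

For example, with a step size of 1.0 and zero point of 0.0, an input of 3.2 maps to 3.25, the nearest representable E4M3 value.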
The fp8_qdq_symbolic
function defines the symbolic computation graph for the fake FP8 quantization. It takes an input value, step size, zero point, and axis as inputs. It first casts the zero point to the FP8 data type, then quantizes with the QuantizeLinear
operator and dequantizes the result with the DequantizeLinear
operator.
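The op sequence this produces (Cast, then QuantizeLinear, then DequantizeLinear) can be illustrated with a tiny stand-in for the graph context. `FakeGraph` is an assumption made for illustration, not OwLite's actual code; 17 is FLOAT8E4M3FN in the ONNX TensorProto data-type enumeration, and the real implementation may handle an absent axis differently.

```python
# Stand-in for torch.onnx's GraphContext: records which ONNX ops are emitted.
class FakeGraph:
    def __init__(self):
        self.ops = []

    def op(self, kind, *args, **attrs):
        self.ops.append(kind)
        return f"{kind}_out"

def fp8_qdq_symbolic_sketch(g, inputs, step_size, zero_point, axis=None):
    # Cast the zero point to FP8 so QuantizeLinear produces an FP8 output;
    # 17 is TensorProto.FLOAT8E4M3FN in the ONNX spec.
    zp = g.op("Cast", zero_point, to_i=17)
    q = g.op("QuantizeLinear", inputs, step_size, zp, axis_i=axis or 0)
    return g.op("DequantizeLinear", q, step_size, zp, axis_i=axis or 0)

g = FakeGraph()
out = fp8_qdq_symbolic_sketch(g, "x", "scale", "zp")
```

Running the sketch records the three ops in order, mirroring the Cast → QuantizeLinear → DequantizeLinear pattern described above.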
Note:
This implementation assumes that the quantization minimum and maximum are valid for FP8 quantization. It also assumes that the input tensor is a PyTorch tensor.
function fake_fp8_quantize
python
fake_fp8_quantize(
inputs: Tensor,
step_size: Tensor,
zero_point: Tensor,
quant_min: float,
quant_max: float,
axis: int | None = None
) → Tensor
Perform fake FP8 quantization on an input tensor.
Args:
inputs (torch.Tensor): The input tensor.
step_size (torch.Tensor): The step size.
zero_point (torch.Tensor): The zero point.
quant_min (float): The quantization minimum.
quant_max (float): The quantization maximum.
axis (int, optional): The axis.
Returns:
Tensor
: The fake-quantized tensor.
function fp8_qdq_symbolic
python
fp8_qdq_symbolic(
g: GraphContext,
inputs: Value,
step_size: Value,
zero_point: Value,
axis: int | None
) → Value | tuple[Value, ...]
Define the symbolic computation graph for fake FP8 quantization.
Args:
g (jit_utils.GraphContext): The graph context.
inputs (torch.Value): The input value.
step_size (torch.Value): The step size.
zero_point (torch.Value): The zero point.
axis (int, optional): The axis.
Returns:
Value | tuple[Value, ...]
: The output value.
class BaseFakeFPQuantizeFunction
An autograd function that performs fake FP quantization.
Static Methods: symbolic: Defines the symbolic computation graph for the function.
method symbolic
python
symbolic(
g: GraphContext,
inputs: Value,
step_size: Value,
zero_point: Value,
grad_scale: float,
quant_min: float,
quant_max: float,
axis: int | None
) → Value | tuple[Value, ...]
Define the symbolic computation graph for the function.
Args:
g (jit_utils.GraphContext): The graph context.
inputs (torch.Value): A tensor to quantize.
step_size (torch.Value): The quantization scale, determining the magnitude of each quantization interval.
zero_point (torch.Tensor): The quantization zero point. It may be nonzero in the context of asymmetric quantization, while for symmetric quantization it is fixed at 0.
grad_scale (float): The gradient scale.
quant_min (float): The lower bound of the quantized domain.
quant_max (float): The upper bound of the quantized domain.
axis (int, optional): Channel axis. Only used when per_channel is True. Defaults to 0.
Returns:
The output value.
Updated: 2024-06-13T23:42:41