torch.quantization
This module implements the functions you call directly to convert a model from FP32 to quantized form. For example, prepare() is used in post-training quantization to prepare the model for the calibration step, and convert() then converts the weights to int8 and replaces the operations with their quantized counterparts. Other helper functions cover tasks such as quantizing the input to the model and performing critical fusions like conv+relu.
Top-level quantization APIs
-
torch.quantization.
quantize
(model, run_fn, run_args, mapping=None, inplace=False)
Quantize the input float model with post training static quantization.
First it prepares the model for calibration, then it calls run_fn to run the calibration step, and finally it converts the model to a quantized model.
- Parameters
model – input float model
run_fn – a calibration function for calibrating the prepared model
run_args – positional arguments for run_fn
inplace – carry out model transformations in-place, the original module is mutated
mapping – correspondence between original module types and quantized counterparts
- Returns
Quantized model.
-
torch.quantization.
quantize_dynamic
(model, qconfig_spec=None, dtype=torch.qint8, mapping=None, inplace=False)
Converts a float model to a dynamic (i.e. weights-only) quantized model.
Replaces the specified modules with dynamic weight-only quantized versions and outputs the quantized model.
For the simplest usage, provide the dtype argument, which can be float16 or qint8. By default, weight-only quantization is performed for layers with large weights, i.e. Linear and RNN variants.
Fine-grained control is possible with qconfig and mapping, which act similarly to quantize(). If qconfig is provided, the dtype argument is ignored.
- Parameters
model – input model
qconfig_spec – Either:
- A dictionary that maps from the name or type of a submodule to a quantization configuration. A qconfig applies to all submodules of a given module unless a qconfig is specified for the submodule (that is, the submodule already has a qconfig attribute). Entries in the dictionary need to be QConfigDynamic instances.
- A set of types and/or submodule names to apply dynamic quantization to, in which case the dtype argument is used to specify the bit-width.
inplace – carry out model transformations in-place, the original module is mutated
mapping – maps type of a submodule to a type of corresponding dynamically quantized version with which the submodule needs to be replaced
-
torch.quantization.
quantize_qat
(model, run_fn, run_args, inplace=False)
Do quantization aware training and output a quantized model.
- Parameters
model – input model
run_fn – a function for evaluating the prepared model, can be a function that simply runs the prepared model or a training loop
run_args – positional arguments for run_fn
- Returns
Quantized model.
-
torch.quantization.
prepare
(model, inplace=False, allow_list=None, observer_non_leaf_module_list=None, prepare_custom_config_dict=None)
Prepares a copy of the model for quantization calibration or quantization-aware training.
Quantization configuration should be assigned preemptively to individual submodules in the .qconfig attribute.
Observer or fake quant modules are attached to the model, and qconfig is propagated.
- Parameters
model – input model to be modified in-place
inplace – carry out model transformations in-place, the original module is mutated
allow_list – list of quantizable modules
observer_non_leaf_module_list – list of non-leaf modules we want to add observer
prepare_custom_config_dict – customization configuration dictionary for prepare function
# Example of prepare_custom_config_dict:
prepare_custom_config_dict = {
    # user will manually define the corresponding observed
    # module class which has a from_float class method that converts
    # float custom module to observed custom module
    "float_to_observed_custom_module_class": {
        CustomModule: ObservedCustomModule
    }
}
-
torch.quantization.
prepare_qat
(model, mapping=None, inplace=False)
Prepares a copy of the model for quantization calibration or quantization-aware training and converts it to a quantized version.
Quantization configuration should be assigned preemptively to individual submodules in .qconfig attribute.
- Parameters
model – input model to be modified in-place
mapping – dictionary that maps float modules to quantized modules to be replaced.
inplace – carry out model transformations in-place, the original module is mutated
-
torch.quantization.
convert
(module, mapping=None, inplace=False, remove_qconfig=True, convert_custom_config_dict=None)
Converts submodules in the input module to a different module according to mapping by calling the from_float method on the target module class, and removes qconfig at the end if remove_qconfig is set to True.
- Parameters
module – prepared and calibrated module
mapping – a dictionary that maps from source module type to target module type, can be overwritten to allow swapping user defined Modules
inplace – carry out model transformations in-place, the original module is mutated
convert_custom_config_dict – custom configuration dictionary for convert function
# Example of convert_custom_config_dict:
convert_custom_config_dict = {
    # user will manually define the corresponding quantized
    # module class which has a from_observed class method that converts
    # observed custom module to quantized custom module
    "observed_to_quantized_custom_module_class": {
        ObservedCustomModule: QuantizedCustomModule
    }
}
-
class
torch.quantization.
QConfig
Describes how to quantize a layer or a part of the network by providing settings (observer classes) for activations and weights respectively.
Note that QConfig needs to contain observer classes (like MinMaxObserver) or a callable that returns instances on invocation, not the concrete observer instances themselves. The quantization preparation function will instantiate observers multiple times for each of the layers.
Observer classes usually have reasonable default arguments, but they can be overridden with the with_args method (which behaves like functools.partial):
my_qconfig = QConfig(activation=MinMaxObserver.with_args(dtype=torch.qint8), weight=default_observer.with_args(dtype=torch.qint8))
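The reason a QConfig holds factories rather than instances can be sketched without torch. The snippet below is only an illustration: ToyObserver and make_observer are hypothetical names, and functools.partial stands in for with_args, which the text above describes as behaving similarly.

```python
from functools import partial


class ToyObserver:
    """Hypothetical stand-in for an observer class; not part of torch."""

    def __init__(self, dtype="quint8"):
        self.dtype = dtype
        self.seen = []  # statistics kept per instance

    def observe(self, value):
        self.seen.append(value)


# A factory with pre-bound arguments, analogous to ToyObserver.with_args(dtype=...).
make_observer = partial(ToyObserver, dtype="qint8")

# A preparation step would call the factory once per layer,
# so every layer gets its own independent observer.
layer_observers = {name: make_observer() for name in ("conv1", "fc")}
layer_observers["conv1"].observe(0.5)
```

Because each layer receives a fresh instance, the statistics recorded for conv1 do not leak into fc, which would happen if a single concrete observer instance were shared.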
-
class
torch.quantization.
QConfigDynamic
Describes how to dynamically quantize a layer or a part of the network by providing settings (observer classes) for weights.
It’s like QConfig, but for dynamic quantization.
Note that QConfigDynamic needs to contain observer classes (like MinMaxObserver) or a callable that returns instances on invocation, not the concrete observer instances themselves. The quantization function will instantiate observers multiple times for each of the layers.
Observer classes usually have reasonable default arguments, but they can be overridden with the with_args method (which behaves like functools.partial):
my_qconfig = QConfigDynamic(weight=default_observer.with_args(dtype=torch.qint8))
Preparing model for quantization
-
torch.quantization.
fuse_modules
(model, modules_to_fuse, inplace=False, fuser_func=<function fuse_known_modules>, fuse_custom_config_dict=None)
Fuses a list of modules into a single module.
Fuses only the following sequences of modules:
- conv, bn
- conv, bn, relu
- conv, relu
- linear, relu
- bn, relu
All other sequences are left unchanged. For these sequences, replaces the first item in the list with the fused module, replacing the rest of the modules with identity.
- Parameters
model – Model containing the modules to be fused
modules_to_fuse – list of list of module names to fuse. Can also be a list of strings if there is only a single list of modules to fuse.
inplace – bool specifying if fusion happens in place on the model, by default a new model is returned
fuser_func – Function that takes in a list of modules and outputs a list of fused modules of the same length. For example, fuser_func([convModule, BNModule]) returns the list [ConvBNModule, nn.Identity()]. Defaults to torch.quantization.fuse_known_modules.
fuse_custom_config_dict – custom configuration for fusion
# Example of fuse_custom_config_dict
fuse_custom_config_dict = {
    # Additional fuser_method mapping
    "additional_fuser_method_mapping": {
        (torch.nn.Conv2d, torch.nn.BatchNorm2d): fuse_conv_bn
    },
}
- Returns
model with fused modules. A new copy is created if inplace=False.
Examples:
>>> m = myModel()
>>> # m is a module containing the sub-modules below
>>> modules_to_fuse = [['conv1', 'bn1', 'relu1'], ['submodule.conv', 'submodule.relu']]
>>> fused_m = torch.quantization.fuse_modules(m, modules_to_fuse)
>>> output = fused_m(input)

>>> m = myModel()
>>> # Alternately provide a single list of modules to fuse
>>> modules_to_fuse = ['conv1', 'bn1', 'relu1']
>>> fused_m = torch.quantization.fuse_modules(m, modules_to_fuse)
>>> output = fused_m(input)
-
class
torch.quantization.
QuantStub
(qconfig=None)
Quantize stub module. Before calibration, this is the same as an observer; it will be swapped for nnq.Quantize in convert.
- Parameters
qconfig – quantization configuration for the tensor; if qconfig is not provided, it is taken from the parent modules
-
class
torch.quantization.
DeQuantStub
Dequantize stub module. Before calibration, this is the same as identity; it will be swapped for nnq.DeQuantize in convert.
-
class
torch.quantization.
QuantWrapper
(module)
A wrapper class that wraps the input module, adds QuantStub and DeQuantStub, and surrounds the call to module with calls to the quant and dequant modules.
This is used by the quantization utility functions to add the quant and dequant modules. Before convert, QuantStub is just an observer that observes the input tensor; after convert, QuantStub is swapped for nnq.Quantize, which does the actual quantization. The same applies to DeQuantStub.
-
torch.quantization.
add_quant_dequant
(module)
Wraps the leaf child module in QuantWrapper if it has a valid qconfig. Note that this function modifies the children of module in place, and it can also return a new module that wraps the input module.
- Parameters
module – input module with qconfig attributes for all the leaf modules that we want to quantize
- Returns
Either the module modified in place, with submodules wrapped in QuantWrapper based on qconfig, or a new QuantWrapper module that wraps the input module. The latter case only happens when the input module is a leaf module and we want to quantize it.
Utility functions
-
torch.quantization.
add_observer_
(module, qconfig_propagation_list=None, non_leaf_module_list=None, device=None, custom_module_class_mapping=None)
Add observers for the leaf children of the module.
This function inserts an observer module into every leaf child module that has a valid qconfig attribute.
- Parameters
module – input module with qconfig attributes for all the leaf modules that we want to quantize
device – parent device, if any
non_leaf_module_list – list of non-leaf modules we want to add observer
- Returns
None, module is modified inplace with added observer modules and forward_hooks
-
torch.quantization.
swap_module
(mod, mapping, custom_module_class_mapping)
Swaps the module if it has a quantized counterpart and it has an observer attached.
- Parameters
mod – input module
mapping – a dictionary that maps from nn module to nnq module
- Returns
The corresponding quantized module of mod
-
torch.quantization.
propagate_qconfig_
(module, qconfig_dict=None, allow_list=None)
Propagate qconfig through the module hierarchy and assign the qconfig attribute on each leaf module.
- Parameters
module – input module
qconfig_dict – dictionary that maps from the name or type of a submodule to a quantization configuration. A qconfig applies to all submodules of a given module unless a qconfig is specified for the submodule (that is, the submodule already has a qconfig attribute)
- Returns
None, module is modified inplace with qconfig attached
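The inheritance rule described for qconfig_dict (a parent's qconfig applies to its submodules unless a submodule already carries its own) can be sketched on a toy tree. This is a simplified illustration that ignores qconfig_dict and allow_list; Node and propagate are hypothetical names, not the torch implementation.

```python
class Node:
    """Hypothetical stand-in for a module: a name, children, optional qconfig."""

    def __init__(self, name, children=(), qconfig=None):
        self.name = name
        self.children = list(children)
        self.qconfig = qconfig


def propagate(node, inherited=None):
    """Assign the inherited qconfig unless the node already has its own."""
    if node.qconfig is None:
        node.qconfig = inherited
    for child in node.children:
        propagate(child, node.qconfig)


leaf_a = Node("conv")                    # no qconfig of its own
leaf_b = Node("fc", qconfig="custom")    # keeps its own qconfig
root = Node("model", [Node("block", [leaf_a]), leaf_b], qconfig="default")
propagate(root)  # modifies the tree in place and returns None
```

After the call, leaf_a has inherited "default" through its parent block, while leaf_b keeps "custom".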
Observers
-
class
torch.quantization.
ObserverBase
(dtype)
Base observer Module. Any observer implementation should derive from this class.
Concrete observers should follow the same API. In forward, they update the statistics of the observed Tensor, and they should provide a calculate_qparams function that computes the quantization parameters given the collected statistics.
- Parameters
dtype – Quantized data type
-
classmethod
with_args
(**kwargs)
Wrapper that allows creation of class factories.
This can be useful when there is a need to create classes with the same constructor arguments, but different instances.
Example:
>>> Foo.with_args = classmethod(_with_args)
>>> foo_builder = Foo.with_args(a=3, b=4).with_args(answer=42)
>>> foo_instance1 = foo_builder()
>>> foo_instance2 = foo_builder()
>>> id(foo_instance1) == id(foo_instance2)
False
-
class
torch.quantization.
MinMaxObserver
(dtype=torch.quint8, qscheme=torch.per_tensor_affine, reduce_range=False, quant_min=None, quant_max=None, factory_kwargs=None)
Observer module for computing the quantization parameters based on the running min and max values.
This observer uses the tensor min/max statistics to compute the quantization parameters. The module records the running minimum and maximum of incoming tensors, and uses this statistic to compute the quantization parameters.
- Parameters
dtype – Quantized data type
qscheme – Quantization scheme to be used
reduce_range – Reduces the range of the quantized data type by 1 bit
quant_min – Minimum quantization value. If unspecified, it will follow the 8-bit setup.
quant_max – Maximum quantization value. If unspecified, it will follow the 8-bit setup.
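Given running min/max as $x_\text{min}$ and $x_\text{max}$, scale $s$ and zero point $z$ are computed as follows.
The running minimum/maximum $x_\text{min/max}$ is computed as:
$$
\begin{aligned}
x_\text{min} &= \begin{cases} \min(X) & \text{if } x_\text{min} = \text{None} \\ \min\left(x_\text{min}, \min(X)\right) & \text{otherwise} \end{cases} \\
x_\text{max} &= \begin{cases} \max(X) & \text{if } x_\text{max} = \text{None} \\ \max\left(x_\text{max}, \max(X)\right) & \text{otherwise} \end{cases}
\end{aligned}
$$
where $X$ is the observed tensor.
The scale $s$ and zero point $z$ are then computed as:
$$
\begin{aligned}
\text{if symmetric:}\quad & s = 2 \max\left(|x_\text{min}|, x_\text{max}\right) / \left(Q_\text{max} - Q_\text{min}\right), \quad z = \begin{cases} 0 & \text{if dtype is qint8} \\ 128 & \text{otherwise} \end{cases} \\
\text{otherwise:}\quad & s = \left(x_\text{max} - x_\text{min}\right) / \left(Q_\text{max} - Q_\text{min}\right), \quad z = Q_\text{min} - \text{round}\left(x_\text{min} / s\right)
\end{aligned}
$$
where $Q_\text{min}$ and $Q_\text{max}$ are the minimum and maximum of the quantized data type.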
Warning
Only works with torch.per_tensor_symmetric quantization scheme.
Warning
dtype can only take torch.qint8 or torch.quint8.
Note
If the running minimum equals the running maximum, the scale and zero_point are set to 1.0 and 0.
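As a rough illustration of the running min/max bookkeeping and the affine (non-symmetric) case, the following pure-Python sketch works on lists of floats instead of tensors. The function names are hypothetical, the formula used is scale = (x_max - x_min) / (Q_max - Q_min) and zero_point = Q_min - round(x_min / scale), and refinements a real observer may apply (such as clamping or forcing the range to include zero) are deliberately left out.

```python
def update_min_max(running, values):
    """Fold a new batch of values into the running (min, max) pair."""
    lo, hi = min(values), max(values)
    if running is None:  # first observation
        return (lo, hi)
    return (min(running[0], lo), max(running[1], hi))


def affine_qparams(x_min, x_max, q_min=0, q_max=255):
    """Scale and zero point for the affine case, 8-bit unsigned by default."""
    if x_min == x_max:
        # Degenerate range: fall back to scale 1.0 and zero_point 0.
        return 1.0, 0
    scale = (x_max - x_min) / (q_max - q_min)
    zero_point = q_min - round(x_min / scale)
    return scale, zero_point


running = None
for batch in ([-1.0, 0.5], [0.25, 3.0]):
    running = update_min_max(running, batch)

scale, zero_point = affine_qparams(*running)
```

Here the running range ends up as (-1.0, 3.0), so the scale is 4/255 and the zero point is 64.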
-
class
torch.quantization.
MovingAverageMinMaxObserver
(averaging_constant=0.01, dtype=torch.quint8, qscheme=torch.per_tensor_affine, reduce_range=False, quant_min=None, quant_max=None, **kwargs)
Observer module for computing the quantization parameters based on the moving average of the min and max values.
This observer computes the quantization parameters based on the moving averages of minimums and maximums of the incoming tensors. The module records the average minimum and maximum of incoming tensors, and uses this statistic to compute the quantization parameters.
- Parameters
averaging_constant – Averaging constant for min/max.
dtype – Quantized data type
qscheme – Quantization scheme to be used
reduce_range – Reduces the range of the quantized data type by 1 bit
quant_min – Minimum quantization value. If unspecified, it will follow the 8-bit setup.
quant_max – Maximum quantization value. If unspecified, it will follow the 8-bit setup.
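The moving average min/max is computed as follows:
$$
\begin{aligned}
x_\text{min} &= \begin{cases} \min(X) & \text{if } x_\text{min} = \text{None} \\ (1 - c)\, x_\text{min} + c \min(X) & \text{otherwise} \end{cases} \\
x_\text{max} &= \begin{cases} \max(X) & \text{if } x_\text{max} = \text{None} \\ (1 - c)\, x_\text{max} + c \max(X) & \text{otherwise} \end{cases}
\end{aligned}
$$
where $x_\text{min/max}$ is the running average min/max, $X$ is the incoming tensor, and $c$ is the averaging_constant.
The scale and zero point are then computed as in MinMaxObserver.
Note
Only works with torch.per_tensor_affine quantization scheme.
Note
If the running minimum equals the running maximum, the scale and zero_point are set to 1.0 and 0.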
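The moving-average update can be sketched in plain Python on lists of floats. The function name is hypothetical, and the update rule assumed here is the usual exponential moving average, (1 - c) * running + c * new, with c the averaging_constant; the first batch initializes the running values directly.

```python
def moving_average_update(running, values, averaging_constant=0.01):
    """Blend the batch min/max into the running (min, max) pair."""
    lo, hi = min(values), max(values)
    if running is None:  # first observation initializes the statistics
        return (lo, hi)
    c = averaging_constant
    return ((1 - c) * running[0] + c * lo,
            (1 - c) * running[1] + c * hi)


running = None
# A large constant (0.5) is used only to make the effect easy to see.
for batch in ([0.0, 1.0], [-1.0, 3.0]):
    running = moving_average_update(running, batch, averaging_constant=0.5)
```

With the default constant of 0.01, a single outlier batch moves the running range by only 1% of the difference, which is why this observer is less sensitive to outliers than a plain running min/max.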
-
class
torch.quantization.
PerChannelMinMaxObserver
(ch_axis=0, dtype=torch.quint8, qscheme=torch.per_channel_affine, reduce_range=False, quant_min=None, quant_max=None, factory_kwargs=None)
Observer module for computing the quantization parameters based on the running per channel min and max values.
This observer uses the tensor min/max statistics to compute the per channel quantization parameters. The module records the running minimum and maximum of incoming tensors, and uses this statistic to compute the quantization parameters.
- Parameters
ch_axis – Channel axis
dtype – Quantized data type
qscheme – Quantization scheme to be used
reduce_range – Reduces the range of the quantized data type by 1 bit
quant_min – Minimum quantization value. If unspecified, it will follow the 8-bit setup.
quant_max – Maximum quantization value. If unspecified, it will follow the 8-bit setup.
The quantization parameters are computed the same way as in MinMaxObserver, with the difference that the running min/max values are stored per channel. Scales and zero points are thus computed per channel as well.
Note
If the running minimum equals the running maximum, the scales and zero_points are set to 1.0 and 0.
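To make the per-channel idea concrete, the sketch below treats a 2-D weight as a list of rows, with each row playing the role of one channel (ch_axis=0). The helper names are hypothetical and the affine formula is the simplified one, scale = (x_max - x_min) / (Q_max - Q_min) and zero_point = Q_min - round(x_min / scale).

```python
def affine_qparams(x_min, x_max, q_min=0, q_max=255):
    """Affine scale and zero point for one channel."""
    if x_min == x_max:
        return 1.0, 0  # degenerate range
    scale = (x_max - x_min) / (q_max - q_min)
    return scale, q_min - round(x_min / scale)


def per_channel_qparams(rows):
    """One (scale, zero_point) pair per row, i.e. per channel along axis 0."""
    return [affine_qparams(min(row), max(row)) for row in rows]


weight = [
    [-1.0, 0.0, 3.0],   # channel 0: range (-1.0, 3.0)
    [0.0, 2.0, 1.0],    # channel 1: range (0.0, 2.0)
    [0.5, 0.5, 0.5],    # channel 2: constant values
]
qparams = per_channel_qparams(weight)
scales = [s for s, _ in qparams]
zero_points = [z for _, z in qparams]
```

Each channel gets its own scale, so a channel with a narrow range is not forced to share the coarse resolution needed by a channel with a wide range.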
-
class
torch.quantization.
MovingAveragePerChannelMinMaxObserver
(averaging_constant=0.01, ch_axis=0, dtype=torch.quint8, qscheme=torch.per_channel_affine, reduce_range=False, quant_min=None, quant_max=None, **kwargs)
Observer module for computing the quantization parameters based on the running per channel min and max values.
This observer uses the tensor min/max statistics to compute the per channel quantization parameters. The module records the running minimum and maximum of incoming tensors, and uses this statistic to compute the quantization parameters.
- Parameters
averaging_constant – Averaging constant for min/max.
ch_axis – Channel axis
dtype – Quantized data type
qscheme – Quantization scheme to be used
reduce_range – Reduces the range of the quantized data type by 1 bit
quant_min – Minimum quantization value. If unspecified, it will follow the 8-bit setup.
quant_max – Maximum quantization value. If unspecified, it will follow the 8-bit setup.
The quantization parameters are computed the same way as in MovingAverageMinMaxObserver, with the difference that the running min/max values are stored per channel. Scales and zero points are thus computed per channel as well.
Note
If the running minimum equals the running maximum, the scales and zero_points are set to 1.0 and 0.
-
class
torch.quantization.
HistogramObserver
(bins=2048, upsample_rate=128, dtype=torch.quint8, qscheme=torch.per_tensor_affine, reduce_range=False, factory_kwargs=None)
The module records the running histogram of tensor values along with min/max values. calculate_qparams will calculate scale and zero_point.
- Parameters
bins – Number of bins to use for the histogram
upsample_rate – Factor by which the histograms are upsampled, this is used to interpolate histograms with varying ranges across observations
dtype – Quantized data type
qscheme – Quantization scheme to be used
reduce_range – Reduces the range of the quantized data type by 1 bit
The scale and zero point are computed as follows:
- Create the histogram of the incoming inputs.
The histogram is computed continuously, and the ranges per bin change with every new tensor observed.
- Search the distribution in the histogram for optimal min/max values.
The search for the min/max values ensures the minimization of the quantization error with respect to the floating point model.
- Compute the scale and zero point the same way as in the MinMaxObserver.
-
class
torch.quantization.
FakeQuantize
(observer=<class 'torch.quantization.observer.MovingAverageMinMaxObserver'>, quant_min=0, quant_max=255, **observer_kwargs)
Simulate the quantize and dequantize operations in training time. The output of this module is given by
x_out = (clamp(round(x/scale + zero_point), quant_min, quant_max)-zero_point)*scale
- scale defines the scale factor used for quantization.
- zero_point specifies the quantized value to which 0 in floating point maps.
- quant_min specifies the minimum allowable quantized value.
- quant_max specifies the maximum allowable quantized value.
- fake_quant_enable controls the application of fake quantization on tensors; note that statistics can still be updated.
- observer_enable controls statistics collection on tensors.
- dtype specifies the quantized dtype that is being emulated with fake quantization; allowable values are torch.qint8 and torch.quint8. The values of quant_min and quant_max should be chosen to be consistent with the dtype.
- Parameters
- Variables
~FakeQuantize.observer (Module) – User provided module that collects statistics on the input tensor and provides a method to calculate scale and zero-point.
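The x_out formula above can be traced on scalars with a few lines of plain Python. This is only a sketch of the arithmetic: the function name is hypothetical, and the real module operates on tensors and obtains scale and zero_point from its observer.

```python
def fake_quantize(x, scale, zero_point, quant_min=0, quant_max=255):
    """Quantize then immediately dequantize a single float value."""
    q = round(x / scale + zero_point)        # map to the integer grid
    q = max(quant_min, min(quant_max, q))    # clamp to the allowed range
    return (q - zero_point) * scale          # map back to floating point


inside = fake_quantize(1.3, scale=0.5, zero_point=10)       # snaps to the grid
too_low = fake_quantize(-100.0, scale=0.5, zero_point=10)   # clamps at quant_min
too_high = fake_quantize(200.0, scale=0.5, zero_point=10)   # clamps at quant_max
```

Values inside the representable range are rounded to the nearest multiple of scale, while values outside it saturate at the ends of the range, which is the error the model learns to tolerate during quantization-aware training.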
-
class
torch.quantization.
NoopObserver
(dtype=torch.float16, custom_op_name='')
Observer that doesn’t do anything and just passes its configuration to the quantized module’s .from_float().
Primarily used for quantization to float16, which doesn’t require determining ranges.
- Parameters
dtype – Quantized data type
custom_op_name – (temporary) specify this observer for an operator that doesn’t require any observation (Can be used in Graph Mode Passes for special case ops).
Debugging utilities
-
torch.quantization.
get_observer_dict
(mod, target_dict, prefix='')
Traverse the modules and save all observers into a dict. This is mainly used for debugging quantization accuracy.
- Parameters
mod – the top module we want to save all observers
prefix – the prefix for the current module
target_dict – the dictionary used to save all the observers
-
class
torch.quantization.
RecordingObserver
(**kwargs)
The module is mainly for debugging and records the tensor values during runtime.
- Parameters
dtype – Quantized data type
qscheme – Quantization scheme to be used
reduce_range – Reduces the range of the quantized data type by 1 bit
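The reduce_range parameter that recurs across the observers can be illustrated with a small helper. The sketch assumes the usual 8-bit ranges (0 to 255 for quint8, -128 to 127 for qint8) and that dropping one bit halves them; quant_range is a hypothetical name and dtypes are passed as plain strings rather than torch dtypes.

```python
def quant_range(dtype, reduce_range=False):
    """Return (quant_min, quant_max) for an 8-bit dtype name."""
    if dtype == "quint8":
        return (0, 127) if reduce_range else (0, 255)
    if dtype == "qint8":
        return (-64, 63) if reduce_range else (-128, 127)
    raise ValueError(f"unsupported dtype: {dtype}")


full = quant_range("quint8")
reduced = quant_range("quint8", reduce_range=True)
signed_reduced = quant_range("qint8", reduce_range=True)
```

A reduced range leaves one bit of headroom, trading some resolution for protection against overflow in the accumulations performed by certain backends.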