
Source code for torch.quantization.fake_quantize

from __future__ import absolute_import, division, print_function, unicode_literals
import torch
from torch.nn import Module
from .observer import MovingAverageMinMaxObserver, HistogramObserver, MovingAveragePerChannelMinMaxObserver, _with_args

class FakeQuantize(Module):
    r""" Simulate the quantize and dequantize operations in training time.
    The output of this module is given by

    x_out = (clamp(round(x/scale + zero_point), quant_min, quant_max) - zero_point) * scale

    * :attr:`scale` defines the scale factor used for quantization.
    * :attr:`zero_point` specifies the quantized value to which 0 in floating point maps to
    * :attr:`quant_min` specifies the minimum allowable quantized value.
    * :attr:`quant_max` specifies the maximum allowable quantized value.
    * :attr:`fake_quant_enable` controls the application of fake quantization on tensors, note that
      statistics can still be updated.
    * :attr:`observer_enable` controls statistics collection on tensors
    * :attr:`dtype` specifies the quantized dtype that is being emulated with fake-quantization,
      allowable values are torch.qint8 and torch.quint8. The values of quant_min and quant_max should
      be chosen to be consistent with the dtype

    Args:
        observer (module): Module for observing statistics on input tensors and calculating
            scale and zero-point.
        quant_min (int): The minimum allowable quantized value.
        quant_max (int): The maximum allowable quantized value.
        observer_kwargs (optional): Arguments for the observer module

    Attributes:
        observer (Module): User provided module that collects statistics on the input tensor
            and provides a method to calculate scale and zero-point.
    """
    def __init__(self, observer=MovingAverageMinMaxObserver, quant_min=0, quant_max=255, **observer_kwargs):
        super(FakeQuantize, self).__init__()
        assert quant_min <= quant_max, \
            'quant_min must be less than or equal to quant_max'
        self.quant_min = quant_min
        self.quant_max = quant_max
        self.fake_quant_enabled = True
        self.observer_enabled = True
        self.observer = observer(**observer_kwargs)
        assert torch.iinfo(self.observer.dtype).min <= quant_min, 'quant_min out of bound'
        assert quant_max <= torch.iinfo(self.observer.dtype).max, 'quant_max out of bound'
        self.scale = None
        self.zero_point = None
        self.dtype = self.observer.dtype
        self.qscheme = self.observer.qscheme
        self.ch_axis = self.observer.ch_axis if hasattr(self.observer, 'ch_axis') else 0

    def enable_fake_quant(self, enabled=True):
        self.fake_quant_enabled = enabled
        return self

    def disable_fake_quant(self):
        return self.enable_fake_quant(False)

    def enable_observer(self, enabled=True):
        self.observer_enabled = enabled
        return self

    def disable_observer(self):
        return self.enable_observer(False)

    def calculate_qparams(self):
        return self.observer.calculate_qparams()

    def forward(self, X):
        if self.observer_enabled:
            self.observer(X.detach())
            self.scale, self.zero_point = self.calculate_qparams()
        if self.fake_quant_enabled:
            if self.qscheme == torch.per_channel_symmetric or self.qscheme == torch.per_channel_affine:
                X = torch.fake_quantize_per_channel_affine(X, self.scale, self.zero_point,
                                                           self.ch_axis, self.quant_min, self.quant_max)
            else:
                X = torch.fake_quantize_per_tensor_affine(X, float(self.scale),
                                                          int(self.zero_point), self.quant_min,
                                                          self.quant_max)
        return X

    with_args = classmethod(_with_args)

    def extra_repr(self):
        return 'fake_quant_enabled={}, observer_enabled={}, scale={}, zero_point={}'.format(
            self.fake_quant_enabled, self.observer_enabled,
            self.scale, self.zero_point)

    def _save_to_state_dict(self, destination, prefix, keep_vars):
        # We cannot currently register scalar values as buffers, so we need to manually
        # specify serialization here.
        super(FakeQuantize, self)._save_to_state_dict(destination, prefix, keep_vars)
        destination[prefix + 'scale'] = self.scale
        destination[prefix + 'zero_point'] = self.zero_point

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        self.scale = state_dict.pop(prefix + 'scale')
        self.zero_point = state_dict.pop(prefix + 'zero_point')
        super(FakeQuantize, self)._load_from_state_dict(state_dict, prefix, local_metadata, False,
                                                        missing_keys, unexpected_keys, error_msgs)
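To make the formula in the docstring concrete, here is a minimal usage sketch (not part of the original source) that runs a tensor through FakeQuantize with its defaults and recomputes the documented expression by hand:

# Minimal usage sketch, assuming the default MovingAverageMinMaxObserver configuration above.
import torch
from torch.quantization.fake_quantize import FakeQuantize

fq = FakeQuantize()                 # quant range [0, 255], per-tensor affine by default
x = torch.randn(4, 8)
y = fq(x)                           # observer records min/max, then fake-quantizes

# Recompute x_out = (clamp(round(x/scale + zero_point), quant_min, quant_max) - zero_point) * scale
scale, zero_point = fq.calculate_qparams()
manual = (torch.clamp(torch.round(x / scale + zero_point), fq.quant_min, fq.quant_max)
          - zero_point) * scale
print(torch.allclose(y, manual))    # expected True, up to floating-point rounding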
default_fake_quant = FakeQuantize.with_args(observer=MovingAverageMinMaxObserver, quant_min=0, quant_max=255,
                                            dtype=torch.quint8, qscheme=torch.per_tensor_affine,
                                            reduce_range=True)
default_weight_fake_quant = FakeQuantize.with_args(observer=MovingAverageMinMaxObserver, quant_min=-128,
                                                   quant_max=127, dtype=torch.qint8,
                                                   qscheme=torch.per_tensor_symmetric, reduce_range=False)
default_per_channel_weight_fake_quant = FakeQuantize.with_args(observer=MovingAveragePerChannelMinMaxObserver,
                                                               quant_min=-128, quant_max=127,
                                                               dtype=torch.qint8,
                                                               qscheme=torch.per_channel_symmetric,
                                                               reduce_range=False, ch_axis=0)
default_histogram_fake_quant = FakeQuantize.with_args(observer=HistogramObserver, quant_min=0, quant_max=255,
                                                      dtype=torch.quint8, qscheme=torch.per_tensor_affine,
                                                      reduce_range=True)

def disable_fake_quant(mod):
    if type(mod) == FakeQuantize:
        mod.disable_fake_quant()

def enable_fake_quant(mod):
    if type(mod) == FakeQuantize:
        mod.enable_fake_quant()

def disable_observer(mod):
    if type(mod) == FakeQuantize:
        mod.disable_observer()

def enable_observer(mod):
    if type(mod) == FakeQuantize:
        mod.enable_observer()
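The with_args partials and the module-level toggles above are typically consumed through a QAT qconfig and nn.Module.apply. The following sketch (not part of the original source, and only one of several possible workflows) shows both:

# Sketch of how the defaults and toggles are commonly used; QConfig comes from torch.quantization.
import torch
from torch.quantization import QConfig
from torch.quantization.fake_quantize import (
    default_fake_quant, default_weight_fake_quant, disable_observer, enable_fake_quant)

# with_args returns a factory: each call builds a fresh, independently configured module.
qat_qconfig = QConfig(activation=default_fake_quant, weight=default_weight_fake_quant)

# The module-level helpers are written for nn.Module.apply(): e.g. freeze observer
# statistics late in training while continuing to simulate quantization noise.
model = torch.nn.Sequential(torch.nn.Linear(8, 8), default_fake_quant())
model(torch.randn(2, 8))        # let the observer see data so scale/zero_point are set
model.apply(disable_observer)   # stop updating statistics
model.apply(enable_fake_quant)  # keep fake-quantizing with the frozen qparams
out = model(torch.randn(2, 8))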
