Shortcuts

Source code for torch.nn.intrinsic.qat.modules.conv_fused

from __future__ import absolute_import, division, print_function, unicode_literals
import torch
import torch.nn as nn
import torch.nn.intrinsic
import torch.nn.qat as nnqat
import torch.nn.functional as F
from torch.nn import init


class ConvBn2d(nn.Conv2d):
    r"""
    A ConvBn2d module is a module fused from Conv2d and BatchNorm2d,
    attached with FakeQuantize modules for both output activation and weight,
    used in quantization aware training.

    We combined the interface of :class:`torch.nn.Conv2d` and
    :class:`torch.nn.BatchNorm2d`.

    Implementation details: https://arxiv.org/pdf/1806.08342.pdf section 3.2.2

    Similar to :class:`torch.nn.Conv2d`, with FakeQuantize modules initialized
    to default.

    Attributes:
        freeze_bn: when True (or when the module is not in training mode) the
            output is normalized with the running statistics and those
            statistics are left untouched; when False in training mode the
            batch statistics are used and the running statistics are updated
        observer: fake quant module for output activation, it's called observer
            to align with post training flow
        weight_fake_quant: fake quant module for weight

    """
    _FLOAT_MODULE = torch.nn.intrinsic.ConvBn2d

    def __init__(self,
                 # Conv2d args
                 in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1,
                 # bias: None, only support Conv with no bias
                 padding_mode='zeros',
                 # BatchNorm2d args
                 # num_features: out_channels
                 eps=1e-05, momentum=0.1,
                 # affine: True
                 # track_running_stats: True
                 # Args for this module
                 freeze_bn=False,
                 qconfig=None):
        """Build the fused conv + batch-norm QAT module.

        Args:
            in_channels, out_channels, kernel_size, stride, padding, dilation,
                groups, padding_mode: forwarded to ``nn.Conv2d``; the
                convolution is always created without a bias.
            eps: value added to the variance before taking the square root.
            momentum: averaging factor for the running statistics; ``None``
                selects a cumulative moving average.
            freeze_bn: initial value of :attr:`freeze_bn`; forced to True when
                the module is not in training mode at construction time.
            qconfig: object whose ``activation()`` and ``weight()`` callables
                create the fake-quantize modules. Required.

        Raises:
            AssertionError: if ``qconfig`` is missing or falsy.
        """
        super(ConvBn2d, self).__init__(in_channels, out_channels, kernel_size,
                                       stride, padding, dilation, groups,
                                       False, padding_mode)
        assert qconfig, 'qconfig must be provided for QAT module'
        self.qconfig = qconfig
        self.eps = eps
        self.momentum = momentum
        self.freeze_bn = freeze_bn if self.training else True
        self.num_features = out_channels
        # gamma / beta play the role of BatchNorm2d's weight / bias.
        self.gamma = nn.Parameter(torch.Tensor(out_channels))
        self.beta = nn.Parameter(torch.Tensor(out_channels))
        self.affine = True
        self.track_running_stats = True
        self.register_buffer('running_mean', torch.zeros(out_channels))
        self.register_buffer('running_var', torch.ones(out_channels))
        self.register_buffer('num_batches_tracked',
                             torch.tensor(0, dtype=torch.long))
        self.observer = self.qconfig.activation()
        self.weight_fake_quant = self.qconfig.weight()
        self.reset_bn_parameters()

    def reset_running_stats(self):
        """Reset running mean to 0, running variance to 1 and the batch counter to 0."""
        self.running_mean.zero_()
        self.running_var.fill_(1)
        self.num_batches_tracked.zero_()

    def reset_bn_parameters(self):
        """Reset the running statistics and re-initialize ``gamma`` and ``beta``."""
        self.reset_running_stats()
        init.uniform_(self.gamma)
        init.zeros_(self.beta)

    def reset_parameters(self):
        """Reset the convolution parameters and, once they exist, the BN ones."""
        super(ConvBn2d, self).reset_parameters()
        # A hack to avoid resetting on undefined parameters: this method may
        # run before gamma/beta have been created in __init__.
        if hasattr(self, 'gamma'):
            self.reset_bn_parameters()

    def update_bn_stats(self):
        """Enable use and update of batch statistics; returns ``self``."""
        self.freeze_bn = False
        return self

    def freeze_bn_stats(self):
        """Freeze batch-norm statistics (use running stats only); returns ``self``."""
        self.freeze_bn = True
        return self

    def _forward(self, input):
        """Run the fused conv + batch-norm computation without the output observer.

        Args:
            input: tensor passed to the convolution.

        Returns:
            The batch-normalized convolution output.

        Raises:
            ValueError: in training mode with unfrozen statistics, if the
                convolution output holds a single value per channel, because
                the unbiased variance estimate is undefined in that case.
        """
        # exponential_average_factor is set to self.momentum
        # (when it is available) only so that it gets updated
        # in ONNX graph when this node is exported to ONNX.
        if self.momentum is None:
            exponential_average_factor = 0.0
        else:
            exponential_average_factor = self.momentum

        if self.training and not self.freeze_bn and self.track_running_stats:
            # TODO: if statement only here to tell the jit to skip emitting this when it is None
            if self.num_batches_tracked is not None:
                self.num_batches_tracked += 1
                if self.momentum is None:  # use cumulative moving average
                    exponential_average_factor = 1.0 / float(self.num_batches_tracked)
                else:  # use exponential moving average
                    exponential_average_factor = self.momentum

        # we use running statistics from the previous batch, so this is an
        # approximation of the approach mentioned in the whitepaper, but we only
        # need to do one convolution in this case instead of two
        running_std = torch.sqrt(self.running_var + self.eps)
        scale_factor = self.gamma / running_std
        scaled_weight = self.weight * scale_factor.reshape([-1, 1, 1, 1])
        conv = self.conv2d_forward(input, self.weight_fake_quant(scaled_weight))

        if self.training and not self.freeze_bn:
            # recovering original conv to get original batch_mean and batch_var
            conv_orig = conv / scale_factor.reshape([1, -1, 1, 1])
            # number of values that contribute to each channel's statistics
            n = float(conv_orig.numel() / conv_orig.size()[1])
            if n <= 1:
                # n / (n - 1) below would otherwise fail with an opaque
                # ZeroDivisionError.
                raise ValueError(
                    'Expected more than 1 value per channel when training, '
                    'got conv output size {}'.format(conv_orig.size()))
            batch_mean = torch.mean(conv_orig, dim=[0, 2, 3])
            batch_var = torch.var(conv_orig, dim=[0, 2, 3], unbiased=False)
            unbiased_batch_var = batch_var * (n / (n - 1))
            batch_rstd = torch.ones_like(batch_var) / torch.sqrt(batch_var + self.eps)

            # undo the running-std scaling folded into the weight and apply
            # the batch statistics instead
            rescale_factor = running_std * batch_rstd
            conv = conv * rescale_factor.reshape([1, -1, 1, 1])
            conv = conv + (self.beta - self.gamma * batch_mean * batch_rstd).reshape([1, -1, 1, 1])

            self.running_mean = exponential_average_factor * batch_mean.detach() + \
                (1 - exponential_average_factor) * self.running_mean
            self.running_var = exponential_average_factor * unbiased_batch_var.detach() + \
                (1 - exponential_average_factor) * self.running_var
        else:
            conv = conv + (self.beta - self.gamma * self.running_mean /
                           running_std).reshape([1, -1, 1, 1])
        return conv

    def extra_repr(self):
        """Return the extra representation string of the parent convolution."""
        # TODO(jerryzh): extend
        return super(ConvBn2d, self).extra_repr()

    def forward(self, input):
        """Apply the fused conv + batch-norm and pass the result through ``observer``."""
        return self.observer(self._forward(input))

    @classmethod
    def from_float(cls, mod, qconfig=None):
        r"""Create a qat module from a float module or qparams_dict

            Args: `mod` a float module, either produced by torch.quantization utilities
            or directly from user

            `qconfig` is optional; when omitted, ``mod.qconfig`` is used.

            The returned module is assigned the float module's weight, BN
            weight/bias and running statistics directly (no copy is made here).
        """
        assert type(mod) == cls._FLOAT_MODULE, 'qat.' + cls.__name__ + '.from_float only works for ' + \
            cls._FLOAT_MODULE.__name__
        if not qconfig:
            assert hasattr(mod, 'qconfig'), 'Input float module must have qconfig defined'
            assert mod.qconfig, 'Input float module must have a valid qconfig'
            qconfig = mod.qconfig
        conv, bn = mod[0], mod[1]
        qat_convbn = cls(conv.in_channels, conv.out_channels, conv.kernel_size,
                         conv.stride, conv.padding, conv.dilation,
                         conv.groups, conv.padding_mode,
                         bn.eps, bn.momentum,
                         False,
                         qconfig)
        assert qat_convbn.bias is None, 'QAT ConvBn should not have bias'
        qat_convbn.weight = conv.weight
        qat_convbn.gamma = bn.weight
        qat_convbn.beta = bn.bias
        qat_convbn.running_mean = bn.running_mean
        qat_convbn.running_var = bn.running_var
        qat_convbn.num_batches_tracked = bn.num_batches_tracked
        return qat_convbn
class ConvBnReLU2d(ConvBn2d):
    r"""
    Quantization aware training module fused from Conv2d, BatchNorm2d and
    ReLU, with FakeQuantize modules attached for both the output activation
    and the weight.

    The interface combines those of :class:`torch.nn.Conv2d`,
    :class:`torch.nn.BatchNorm2d` and :class:`torch.nn.ReLU`.

    Implementation details: https://arxiv.org/pdf/1806.08342.pdf

    Similar to `torch.nn.Conv2d`, with FakeQuantize modules initialized to
    default.

    Attributes:
        observer: fake quant module for output activation, it's called observer
            to align with post training flow
        weight_fake_quant: fake quant module for weight

    """
    _FLOAT_MODULE = torch.nn.intrinsic.ConvBnReLU2d

    def __init__(self,
                 # Conv2d args
                 in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1,
                 # bias: None, only support Conv with no bias
                 padding_mode='zeros',
                 # BatchNorm2d args
                 # num_features: out_channels
                 eps=1e-05, momentum=0.1,
                 # affine: True
                 # track_running_stats: True
                 # Args for this module
                 freeze_bn=False,
                 qconfig=None):
        """Forward every argument unchanged to the ConvBn2d constructor."""
        super(ConvBnReLU2d, self).__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            padding_mode=padding_mode,
            eps=eps,
            momentum=momentum,
            freeze_bn=freeze_bn,
            qconfig=qconfig)

    def forward(self, input):
        """Apply the fused conv + batch-norm, then ReLU, then the output observer."""
        fused = super(ConvBnReLU2d, self)._forward(input)
        rectified = F.relu(fused)
        return self.observer(rectified)

    @classmethod
    def from_float(cls, mod, qconfig=None):
        """Build the QAT module from a float module by delegating to the parent class."""
        parent = super(ConvBnReLU2d, cls)
        return parent.from_float(mod, qconfig)
class ConvReLU2d(nnqat.Conv2d):
    r"""
    A ConvReLU2d module is a fused module of Conv2d and ReLU, attached with
    FakeQuantize modules for both output activation and weight for
    quantization aware training.

    We combined the interface of :class:`~torch.nn.Conv2d` and
    :class:`~torch.nn.ReLU`.

    Attributes:
        observer: fake quant module for output activation, it's called observer
            to align with post training flow
        weight_fake_quant: fake quant module for weight

    """
    _FLOAT_MODULE = torch.nn.intrinsic.ConvReLU2d

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1,
                 bias=True, padding_mode='zeros',
                 qconfig=None):
        """Build the fused conv + ReLU QAT module.

        Args:
            in_channels, out_channels, kernel_size, stride, padding, dilation,
                groups, bias, padding_mode: forwarded to the parent QAT
                convolution.
            qconfig: object whose ``activation()`` and ``weight()`` callables
                create the fake-quantize modules. Required.

        Raises:
            AssertionError: if ``qconfig`` is missing or falsy.
        """
        # Reject a missing qconfig before it is handed to the parent
        # constructor, so the failure is reported from this class.
        assert qconfig, 'qconfig must be provided for QAT module'
        super(ConvReLU2d, self).__init__(in_channels, out_channels, kernel_size,
                                         stride=stride, padding=padding, dilation=dilation,
                                         groups=groups, bias=bias, padding_mode=padding_mode,
                                         qconfig=qconfig)
        self.qconfig = qconfig
        # NOTE(review): the parent constructor may already create these
        # modules - confirm against nnqat.Conv2d before removing.
        self.observer = self.qconfig.activation()
        self.weight_fake_quant = self.qconfig.weight()

    def forward(self, input):
        """Convolve with the fake-quantized weight, apply ReLU, then the output observer."""
        return self.observer(F.relu(super(ConvReLU2d, self).conv2d_forward(input,
                             self.weight_fake_quant(self.weight))))

    @classmethod
    def from_float(cls, mod, qconfig=None):
        """Build the QAT module from a float module by delegating to the parent class."""
        return super(ConvReLU2d, cls).from_float(mod, qconfig)
def update_bn_stats(mod):
    """Re-enable batch-norm statistics on a fused conv-bn QAT module.

    Calls ``mod.update_bn_stats()`` only when ``mod`` is exactly a
    ``ConvBnReLU2d`` or ``ConvBn2d`` instance (subclasses do not match);
    any other object is left untouched. Returns None.
    """
    if type(mod) not in {ConvBnReLU2d, ConvBn2d}:
        return
    mod.update_bn_stats()


def freeze_bn_stats(mod):
    """Freeze batch-norm statistics on a fused conv-bn QAT module.

    Calls ``mod.freeze_bn_stats()`` only when ``mod`` is exactly a
    ``ConvBnReLU2d`` or ``ConvBn2d`` instance (subclasses do not match);
    any other object is left untouched. Returns None.
    """
    if type(mod) not in {ConvBnReLU2d, ConvBn2d}:
        return
    mod.freeze_bn_stats()

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources