torch.nn.intrinsic.qat

This module implements the quantization-aware training (QAT) versions of the fused operations defined in torch.nn.intrinsic.

ConvBn2d

class torch.nn.intrinsic.qat.ConvBn2d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=None, padding_mode='zeros', eps=1e-05, momentum=0.1, freeze_bn=False, qconfig=None)[source]

A ConvBn2d module is a module fused from Conv2d and BatchNorm2d, with a FakeQuantize module attached for the weight, used in quantization-aware training.

It combines the interfaces of torch.nn.Conv2d and torch.nn.BatchNorm2d.

Similar to torch.nn.Conv2d, with the FakeQuantize modules initialized to their defaults.

Variables
  • ConvBn2d.freeze_bn

  • ConvBn2d.weight_fake_quant – fake quant module for weight
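
The two variables above can be inspected directly on an instance. The following is a minimal sketch, not a reference recipe: it assumes the default QAT qconfig for the "fbgemm" backend (the backend choice, tensor sizes, and variable names are illustrative), and it uses the module's freeze_bn_stats() method to stop the BatchNorm running statistics from updating.

```python
import torch
import torch.nn.intrinsic.qat as nniqat
from torch.quantization import get_default_qat_qconfig

# A qconfig must be supplied; the weight FakeQuantize module is built from it.
# "fbgemm" is an assumed backend choice for this sketch.
qconfig = get_default_qat_qconfig("fbgemm")

# Conv2d + BatchNorm2d as one trainable module (3 -> 8 channels, 3x3 kernel).
conv_bn = nniqat.ConvBn2d(3, 8, kernel_size=3, qconfig=qconfig)

x = torch.randn(4, 3, 16, 16)
y = conv_bn(x)                  # training-mode forward; BN statistics update
out_shape = tuple(y.shape)      # no padding, so each spatial dim shrinks by 2

has_weight_fq = hasattr(conv_bn, "weight_fake_quant")
frozen_before = conv_bn.freeze_bn

# Freeze the BN running statistics (often done late in QAT fine-tuning).
conv_bn.freeze_bn_stats()
frozen_after = conv_bn.freeze_bn

# With frozen statistics, another forward pass leaves running_mean untouched.
mean_before = conv_bn.bn.running_mean.clone()
conv_bn(x)
mean_unchanged = torch.equal(mean_before, conv_bn.bn.running_mean)
```

Calling update_bn_stats() on the module reverses the freeze. The same pattern should apply to ConvBnReLU2d, which shares this interface.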

ConvBnReLU2d

class torch.nn.intrinsic.qat.ConvBnReLU2d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=None, padding_mode='zeros', eps=1e-05, momentum=0.1, freeze_bn=False, qconfig=None)[source]

A ConvBnReLU2d module is a module fused from Conv2d, BatchNorm2d and ReLU, with a FakeQuantize module attached for the weight, used in quantization-aware training.

It combines the interfaces of torch.nn.Conv2d, torch.nn.BatchNorm2d and torch.nn.ReLU.

Similar to torch.nn.Conv2d, with the FakeQuantize modules initialized to their defaults.

Variables

ConvBnReLU2d.weight_fake_quant – fake quant module for weight

ConvReLU2d

class torch.nn.intrinsic.qat.ConvReLU2d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros', qconfig=None)[source]

A ConvReLU2d module is a fused module of Conv2d and ReLU, with a FakeQuantize module attached for the weight, used in quantization-aware training.

It combines the interfaces of torch.nn.Conv2d and torch.nn.ReLU.

Variables

ConvReLU2d.weight_fake_quant – fake quant module for weight
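
As an illustration of the fused Conv2d + ReLU behaviour, the sketch below builds a ConvReLU2d, runs a forward pass, and applies weight_fake_quant to the weight by hand. The "fbgemm" qconfig, the tensor sizes, and the variable names are assumptions made for the example.

```python
import torch
import torch.nn.intrinsic.qat as nniqat
from torch.quantization import get_default_qat_qconfig

# Assumed backend choice; any valid QAT qconfig could be used instead.
qconfig = get_default_qat_qconfig("fbgemm")

# Conv2d followed by ReLU, with fake quantization applied to the weight.
conv_relu = nniqat.ConvReLU2d(3, 8, kernel_size=3, padding=1, qconfig=qconfig)

x = torch.randn(4, 3, 16, 16)
y = conv_relu(x)

out_shape = tuple(y.shape)      # padding=1 keeps the 16x16 spatial size
min_value = y.min().item()      # ReLU clamps every output to be non-negative

# Fake quantization simulates quantized values but keeps a float tensor
# of the same shape, so training can continue with ordinary autograd.
fq_weight = conv_relu.weight_fake_quant(conv_relu.weight)
same_shape = fq_weight.shape == conv_relu.weight.shape
stays_float = fq_weight.dtype == torch.float32
```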

ConvBn3d

class torch.nn.intrinsic.qat.ConvBn3d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=None, padding_mode='zeros', eps=1e-05, momentum=0.1, freeze_bn=False, qconfig=None)[source]

A ConvBn3d module is a module fused from Conv3d and BatchNorm3d, with a FakeQuantize module attached for the weight, used in quantization-aware training.

It combines the interfaces of torch.nn.Conv3d and torch.nn.BatchNorm3d.

Similar to torch.nn.Conv3d, with the FakeQuantize modules initialized to their defaults.

Variables
  • ConvBn3d.freeze_bn

  • ConvBn3d.weight_fake_quant – fake quant module for weight
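
The 3d variants take 5-dimensional input of shape (N, C, D, H, W). The sketch below is a hedged example under an assumed "fbgemm" qconfig with illustrative sizes; it runs one forward and backward pass to show that gradients still reach the float weight through the fake-quantization step.

```python
import torch
import torch.nn.intrinsic.qat as nniqat
from torch.quantization import get_default_qat_qconfig

# Assumed backend choice for this sketch.
qconfig = get_default_qat_qconfig("fbgemm")

# Conv3d + BatchNorm3d fused for QAT (2 -> 4 channels, 3x3x3 kernel).
conv_bn3d = nniqat.ConvBn3d(2, 4, kernel_size=3, padding=1, qconfig=qconfig)

# Input layout is (batch, channels, depth, height, width).
x = torch.randn(2, 2, 6, 8, 8)
y = conv_bn3d(x)
out_shape = tuple(y.shape)      # padding=1 preserves depth, height and width

# Backpropagate a dummy loss; the float weight receives a gradient even
# though the forward pass used its fake-quantized version.
y.sum().backward()
has_grad = conv_bn3d.weight.grad is not None
```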

ConvBnReLU3d

class torch.nn.intrinsic.qat.ConvBnReLU3d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=None, padding_mode='zeros', eps=1e-05, momentum=0.1, freeze_bn=False, qconfig=None)[source]

A ConvBnReLU3d module is a module fused from Conv3d, BatchNorm3d and ReLU, with a FakeQuantize module attached for the weight, used in quantization-aware training.

It combines the interfaces of torch.nn.Conv3d, torch.nn.BatchNorm3d and torch.nn.ReLU.

Similar to torch.nn.Conv3d, with the FakeQuantize modules initialized to their defaults.

Variables

ConvBnReLU3d.weight_fake_quant – fake quant module for weight

ConvReLU3d

class torch.nn.intrinsic.qat.ConvReLU3d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros', qconfig=None)[source]

A ConvReLU3d module is a fused module of Conv3d and ReLU, with a FakeQuantize module attached for the weight, used in quantization-aware training.

It combines the interfaces of torch.nn.Conv3d and torch.nn.ReLU.

Variables

ConvReLU3d.weight_fake_quant – fake quant module for weight

LinearReLU

class torch.nn.intrinsic.qat.LinearReLU(in_features, out_features, bias=True, qconfig=None)[source]

A LinearReLU module is a module fused from Linear and ReLU, with a FakeQuantize module attached for the weight, used in quantization-aware training.

It adopts the same interface as torch.nn.Linear.

Similar to torch.nn.intrinsic.LinearReLU, with the FakeQuantize modules initialized to their defaults.

Variables

LinearReLU.weight_fake_quant – fake quant module for weight

Examples:

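>>> import torch
>>> import torch.nn.intrinsic.qat as nniqat
>>> from torch.quantization import get_default_qat_qconfig
>>> # QAT modules require a qconfig; the 'fbgemm' default is used here
>>> m = nniqat.LinearReLU(20, 30, qconfig=get_default_qat_qconfig('fbgemm'))
>>> input = torch.randn(128, 20)
>>> output = m(input)
>>> print(output.size())
torch.Size([128, 30])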
