Eager Mode Quantization#

PTDQ (Eager)#

PTDQ (Post Training Dynamic Quantization) is the simplest form of quantization to apply: the weights are quantized ahead of time, while the activations are dynamically quantized during inference. It is used when model execution time is dominated by loading weights from memory rather than by computing matrix multiplications, which is the case for LSTM and Transformer type models with small batch sizes.

PTDQ diagram

Original model (all tensors and computations are in floating point):

previous_layer_fp32 -- linear_fp32 -- activation_fp32 -- next_layer_fp32
                 /
linear_weight_fp32

Dynamically quantized model (linear and LSTM weights are int8):

previous_layer_fp32 -- linear_int8_w_fp32_inp -- activation_fp32 -- next_layer_fp32
                     /
   linear_weight_int8

Example:

import torch

# define a floating point model
class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(4, 4)

    def forward(self, x):
        x = self.fc(x)
        return x

# create a model instance
model_fp32 = M()
# create a quantized model instance
model_int8 = torch.quantization.quantize_dynamic(
    model_fp32,  # the original model
    {torch.nn.Linear},  # a set of layers to dynamically quantize
    dtype=torch.qint8)  # the target dtype for quantized weights

# run the model
input_fp32 = torch.randn(4, 4, 4, 4)
res = model_int8(input_fp32)
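
Continuing the snippet above, a quick check (a minimal sketch; the printed class name may vary across PyTorch versions) confirms that only the weights were quantized while inputs and outputs stay fp32:

# the Linear layer has been replaced by its dynamically quantized counterpart,
# e.g. DynamicQuantizedLinear; inputs and outputs remain fp32
print(model_int8.fc)

# outputs are floating point and should closely match the fp32 model
res_fp32 = model_fp32(input_fp32)
print((res - res_fp32).abs().max())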

See also

See the Dynamic Quantization tutorial.

PTSQ (Eager)#

PTSQ (Post Training Static Quantization) quantizes both the weights and the activations of the model. Where possible, it fuses activations into preceding layers. It requires calibration with a representative dataset to determine optimal quantization parameters for the activations. Post training static quantization is typically used when both memory bandwidth and compute savings matter, with CNNs being the typical use case.

Warning

The model may need to be modified before post training static quantization can be applied; see the sketch below for one common modification.
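
For example, eager mode quantization cannot observe free functions such as `+` or `torch.add`, so element-wise ops that should be quantized are routed through `torch.nn.quantized.FloatFunctional` modules; placing `QuantStub`/`DeQuantStub` modules, shown in the example below, is another such modification. A minimal sketch with a hypothetical residual block (not part of the example that follows):

import torch

class ResidualBlock(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(4, 4, 3, padding=1)
        # FloatFunctional gives the addition a module identity, so an observer
        # can be attached to it and it can be swapped for a quantized add
        # during convert; a bare `x + y` would stay unquantized
        self.skip_add = torch.nn.quantized.FloatFunctional()

    def forward(self, x):
        return self.skip_add.add(self.conv(x), x)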

PTSQ diagram

Original model (all tensors and computations are in floating point):

previous_layer_fp32 -- linear_fp32 -- activation_fp32 -- next_layer_fp32
                    /
    linear_weight_fp32

Statically quantized model (weights and activations are int8):

previous_layer_int8 -- linear_with_activation_int8 -- next_layer_int8
                    /
  linear_weight_int8

Example:

import torch

# define a floating point model where some layers could be statically quantized
class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # QuantStub converts tensors from floating point to quantized
        self.quant = torch.quantization.QuantStub()
        self.conv = torch.nn.Conv2d(1, 1, 1)
        self.relu = torch.nn.ReLU()
        # DeQuantStub converts tensors from quantized to floating point
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x):
        # manually specify where tensors will be converted from floating
        # point to quantized in the quantized model
        x = self.quant(x)
        x = self.conv(x)
        x = self.relu(x)
        # manually specify where tensors will be converted from quantized
        # to floating point in the quantized model
        x = self.dequant(x)
        return x

# create a model instance
model_fp32 = M()

# model must be set to eval mode for static quantization logic to work
model_fp32.eval()

# attach a global qconfig, which contains information about what kind
# of observers to attach. Use 'fbgemm' for server inference and
# 'qnnpack' for mobile inference. Other quantization configurations such
# as selecting symmetric or asymmetric quantization and MinMax or L2Norm
# calibration techniques can be specified here.
model_fp32.qconfig = torch.quantization.get_default_qconfig('fbgemm')

# Fuse the activations to preceding layers, where applicable.
# This needs to be done manually depending on the model architecture.
# Common fusions include `conv + relu` and `conv + batchnorm + relu`
model_fp32_fused = torch.quantization.fuse_modules(model_fp32, [['conv', 'relu']])

# Prepare the model for static quantization. This inserts observers in
# the model that will observe activation tensors during calibration.
model_fp32_prepared = torch.quantization.prepare(model_fp32_fused)

# calibrate the prepared model to determine quantization parameters for activations
# in a real world setting, the calibration would be done with a representative dataset
input_fp32 = torch.randn(4, 1, 4, 4)
model_fp32_prepared(input_fp32)

# Convert the observed model to a quantized model. This does several things:
# quantizes the weights, computes and stores the scale and zero_point values to be
# used with each activation tensor, and replaces key operators with quantized
# implementations.
model_int8 = torch.quantization.convert(model_fp32_prepared)

# run the model, relevant calculations will happen in int8
res = model_int8(input_fp32)
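
Continuing the snippet above, the converted model can be inspected to confirm that fusion and quantization took place (a minimal sketch; exact module names may differ between PyTorch versions):

# conv + relu has been fused and replaced by a quantized module such as
# QuantizedConvReLU2d; the stubs become Quantize / DeQuantize modules
print(model_int8)

# the output is fp32 again because of the DeQuantStub at the end of forward()
print(res.dtype)  # torch.float32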

See also

See the PTSQ (Static Quantization) tutorial.

QAT (Eager)#

QAT (Quantization Aware Training) models the effects of quantization during training, which allows it to reach higher accuracy than the other quantization methods. QAT can be applied to static, dynamic, or weight-only quantization. During training all computations are done in floating point, while fake_quant modules model the effects of quantization by clamping and rounding to simulate int8 numerics. After model conversion, the weights and activations are quantized, and activations are fused into preceding layers where possible. QAT is commonly used with CNNs and yields higher accuracy than static quantization.
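
The clamping and rounding performed by a fake_quant module can be written out directly. Below is a minimal sketch of per-tensor affine fake quantization; `scale`, `zero_point`, `qmin` and `qmax` are illustrative values, whereas in practice they come from observers:

import torch

def fake_quantize(x, scale=0.1, zero_point=0, qmin=-128, qmax=127):
    # quantize: scale, shift, round and clamp to the int8 range ...
    q = torch.clamp(torch.round(x / scale) + zero_point, qmin, qmax)
    # ... then immediately dequantize, so the tensor stays fp32 but carries
    # the rounding/clamping error that int8 inference would introduce
    return (q - zero_point) * scale

x = torch.randn(2, 3)
print(fake_quantize(x))

The real FakeQuantize modules additionally use a straight-through estimator in the backward pass, so gradients flow through the rounding during training.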

QAT diagram

Original model (all tensors and computations are in floating point):

previous_layer_fp32 -- linear_fp32 -- activation_fp32 -- next_layer_fp32
                      /
    linear_weight_fp32

Model with fake_quants inserted for modeling quantization numerics during training:

previous_layer_fp32 -- fq -- linear_fp32 -- activation_fp32 -- fq -- next_layer_fp32
                           /
   linear_weight_fp32 -- fq

Quantized model (weights and activations are int8):

previous_layer_int8 -- linear_with_activation_int8 -- next_layer_int8
                     /
   linear_weight_int8

Example:

import torch

# define a floating point model where some layers could benefit from QAT
class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # QuantStub converts tensors from floating point to quantized
        self.quant = torch.quantization.QuantStub()
        self.conv = torch.nn.Conv2d(1, 1, 1)
        self.bn = torch.nn.BatchNorm2d(1)
        self.relu = torch.nn.ReLU()
        # DeQuantStub converts tensors from quantized to floating point
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x):
        x = self.quant(x)
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        x = self.dequant(x)
        return x

# create a model instance
model_fp32 = M()

# model must be set to eval for fusion to work
model_fp32.eval()

# attach a global qconfig, which contains information about what kind
# of observers to attach. Use 'fbgemm' for server inference and
# 'qnnpack' for mobile inference. Other quantization configurations such
# as selecting symmetric or asymmetric quantization and MinMax or L2Norm
# calibration techniques can be specified here.
model_fp32.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')

# fuse the activations to preceding layers, where applicable
# this needs to be done manually depending on the model architecture
model_fp32_fused = torch.quantization.fuse_modules(model_fp32,
    [['conv', 'bn', 'relu']])

# Prepare the model for QAT. This inserts observers and fake_quants in
# the model that will observe weight and activation tensors during calibration.
# The model needs to be set to train mode for QAT logic to work.
model_fp32_prepared = torch.quantization.prepare_qat(model_fp32_fused.train())

# run the training loop (not shown)
training_loop(model_fp32_prepared)

# Convert the observed model to a quantized model. This does several things:
# quantizes the weights, computes and stores the scale and zero_point values to be
# used with each activation tensor, fuses modules where appropriate,
# and replaces key operators with quantized implementations.
model_fp32_prepared.eval()
model_int8 = torch.quantization.convert(model_fp32_prepared)

# run the model, relevant calculations will happen in int8
input_fp32 = torch.randn(4, 1, 4, 4)
res = model_int8(input_fp32)
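
The `training_loop` above is intentionally not shown. A minimal sketch of what it could look like, assuming random data and an MSE objective purely as placeholders:

def training_loop(model, num_steps=10):
    # placeholder loop: random inputs/targets and an MSE loss, just enough
    # to exercise the observers and fake_quant modules during QAT
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
    loss_fn = torch.nn.MSELoss()
    for _ in range(num_steps):
        inputs = torch.randn(4, 1, 4, 4)
        targets = torch.randn(4, 1, 4, 4)
        optimizer.zero_grad()
        loss = loss_fn(model(inputs), targets)
        loss.backward()
        optimizer.step()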

Using Pre-trained Quantized Models#

Compare the size of the non-quantized MobileNet v2 model with its quantized version:

import os

import torch
from torchvision.models import mobilenet_v2, MobileNet_V2_Weights
from torchvision.models.quantization import mobilenet_v2 as qmobilenet_v2, MobileNet_V2_QuantizedWeights


model_quantized = qmobilenet_v2(weights=MobileNet_V2_QuantizedWeights.DEFAULT, quantize=True)
model = mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT)


def print_model_size(mdl, temp_path="tmp.pt"):
    torch.save(mdl.state_dict(), temp_path)
    size = os.path.getsize(temp_path)/1e6
    print(f"{size:.2f} MB")
    os.remove(temp_path)

print_model_size(model)
print_model_size(model_quantized)
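
As a quick sanity check continuing the snippet above (a minimal sketch), both models accept the same fp32 input and produce logits of the same shape; only the stored weights and the kernels used internally differ:

model.eval()
model_quantized.eval()

with torch.no_grad():
    x = torch.rand(1, 3, 224, 224)
    print(model(x).shape)            # torch.Size([1, 1000])
    print(model_quantized(x).shape)  # torch.Size([1, 1000])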