FX Quantization Quickstart#

References:

  1. Quantization in Practice

  2. FX Graph Mode Post Training Static Quantization

This tutorial walks through the steps of post training static quantization in graph mode, based on torch.fx. The advantage of FX Graph Mode Quantization is that quantization is performed fully automatically on the model, although some effort may be needed to make the model compatible with it (the model must be symbolically traceable with torch.fx). A separate tutorial will show how to make the parts of the model we want to quantize compatible with FX Graph Mode Quantization. There is also an FX Graph Mode Post Training Dynamic Quantization tutorial. The FX Graph Mode API looks like this:

import torch
from torch import nn
from torch.quantization import get_default_qconfig
from torch.quantization.quantize_fx import prepare_fx, convert_fx

class M(nn.Module):
    def forward(self, x):
        return x

float_model = M()
float_model.eval()
qconfig = get_default_qconfig("fbgemm")
qconfig_dict = {"": qconfig}
def calibrate(model, data_loader):
    model.eval()
    with torch.no_grad():
        for image, target in data_loader:
            model(image)

valset = []
prepared_model = prepare_fx(float_model, qconfig_dict)  # fuse modules and insert observers
calibrate(prepared_model, valset)  # run calibration on representative data
quantized_model = convert_fx(prepared_model)  # convert the calibrated model to a quantized model

Motivation for FX Graph Mode Quantization#

PyTorch currently has eager mode quantization: Static Quantization with Eager Mode in PyTorch.

As you can see, that process involves multiple manual steps, including:

  • Explicitly quantize and dequantize activations; this is time-consuming when floating point and quantized operations are mixed in a model.

  • Explicitly fuse modules; this requires manually identifying sequences of convolutions, batch norms, relus, and other fusion patterns.

  • PyTorch tensor operations need special handling (such as add, concat, etc.).

  • Functionals do not have first class support (functional.conv2d and functional.linear will not be quantized).

Most of these required changes come from the underlying limitations of eager mode quantization. Eager mode works at the module level because it cannot inspect the code that is actually run (in the forward function); quantization is achieved by module swapping, and in eager mode we do not know how the modules are used in the forward function. It therefore requires users to manually insert QuantStub and DeQuantStub to mark the points where they want to quantize or dequantize. In graph mode, we can inspect the actual code executed in the forward function (e.g. aten function calls), and quantization is achieved by module and graph manipulations. Since graph mode has full visibility into the code that runs, it can automatically figure out which modules to fuse and where to insert observer calls, quantize/dequantize functions, etc., so the whole quantization process can be automated.
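For contrast, here is a minimal sketch of the manual stub insertion that eager mode requires (the layer choices are illustrative, not taken from this tutorial):

import torch
from torch import nn

class EagerQuantExample(nn.Module):
    def __init__(self):
        super().__init__()
        self.quant = torch.quantization.QuantStub()      # user marks the quantize point
        self.conv = nn.Conv2d(3, 16, 3)
        self.relu = nn.ReLU()
        self.dequant = torch.quantization.DeQuantStub()  # user marks the dequantize point

    def forward(self, x):
        x = self.quant(x)       # manual: float -> quantized
        x = self.relu(self.conv(x))
        return self.dequant(x)  # manual: quantized -> float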

The advantages of FX Graph Mode Quantization are:

  • It simplifies the quantization flow, minimizing manual steps

  • It opens the door to higher level optimizations, such as automatic precision selection

Define Helper Functions and Prepare the Dataset#

First we do the necessary imports, define some helper functions, and prepare the data. These steps are identical to those in Static Quantization with Eager Mode in PyTorch.

To run the code in this tutorial using the entire ImageNet dataset, first download ImageNet by following the instructions in ImageNet Data. Unzip the downloaded file into the data_path folder.

Download the torchvision resnet18 model and rename it to models/resnet18_pretrained_float.pth.
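A minimal sketch of that step, assuming the pretrained float weights are fetched through torchvision (the paths follow the tutorial's convention):

import os
import torch
from torchvision.models import resnet18

os.makedirs("models", exist_ok=True)
# pretrained=True downloads the float weights through torchvision
float_weights = resnet18(pretrained=True).state_dict()
torch.save(float_weights, "models/resnet18_pretrained_float.pth")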

from torch_book.data import ImageNet


root = "/media/pc/data/4tb/lxw/datasets/ILSVRC"
saved_model_dir = 'models/'

dataset = ImageNet(root)
trainset = dataset.loader(batch_size=30, split="train")
valset = dataset.loader(batch_size=50, split="val")

import copy
from torchvision import models

model_name = "resnet18"
float_model = getattr(models, model_name)(pretrained=True)
float_model.eval()

# deepcopy the model since we need to keep the original model around
model_to_quantize = copy.deepcopy(float_model)

Set the Model to Eval Mode#

For post training quantization, the model needs to be set to evaluation mode.

model_to_quantize.eval();

Specify How to Quantize the Model with qconfig_dict#

from torch.quantization import default_qconfig

qconfig_dict = {"": default_qconfig}

We use the same qconfig as in eager mode quantization; a qconfig is just a named tuple of the observers used for activations and weights. qconfig_dict is a dictionary with the following configuration:

qconfig = {
    "" : qconfig_global,
    "sub" : qconfig_sub,
    "sub.fc" : qconfig_fc,
    "sub.conv": None
}
qconfig_dict = {
    # qconfig? means either a valid qconfig or None
    # optional, global config
    "": qconfig?,
    # optional, used for module and function types
    # could also be split into module_types and function_types if we prefer
    "object_type": [
        (torch.nn.Conv2d, qconfig?),
        (torch.nn.functional.add, qconfig?),
        ...,
    ],
    # optional, used for module names
    "module_name": [
        ("foo.bar", qconfig?)
        ...,
    ],
    # optional, matched in order, first match takes precedence
    "module_name_regex": [
        ("foo.*bar.*conv[0-9]+", qconfig?)
        ...,
    ],
    # priority (in increasing order): global, object_type, module_name_regex, module_name
    # qconfig == None means fusion and quantization should be skipped for anything
    # matching the rule

    # **api subject to change**
    # optional: specify the path for standalone modules
    # These modules are symbolically traced and quantized as one unit
    # so that the call to the submodule appears as one call_module
    # node in the forward graph of the GraphModule
    "standalone_module_name": [
        "submodule.standalone"
    ],
    "standalone_module_class": [
        StandaloneModuleClass
    ]
}
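For example, a concrete qconfig_dict for this tutorial's resnet18 might keep the global qconfig everywhere but skip the final classifier (named fc in torchvision's resnet18); this particular configuration is an illustration, not part of the original tutorial:

qconfig_dict = {
    "": qconfig,                    # global default for everything else
    "module_name": [("fc", None)],  # skip fusion/quantization for the classifier
}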

Utility functions related to qconfig can be found in the qconfig file:

from torch.quantization import get_default_qconfig

qconfig = get_default_qconfig("fbgemm")
qconfig_dict = {"": qconfig}
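If the defaults do not fit, a qconfig can also be assembled by hand from observers. A minimal sketch (the observer choices and arguments here are illustrative):

import torch
from torch.quantization import QConfig, MinMaxObserver, PerChannelMinMaxObserver

custom_qconfig = QConfig(
    # per-tensor affine quantization for activations
    activation=MinMaxObserver.with_args(dtype=torch.quint8),
    # per-channel symmetric quantization for weights
    weight=PerChannelMinMaxObserver.with_args(
        dtype=torch.qint8, qscheme=torch.per_channel_symmetric),
)
custom_qconfig_dict = {"": custom_qconfig}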

Prepare the Model for Post Training Static Quantization#

import warnings
from torch.quantization.quantize_fx import prepare_fx

warnings.filterwarnings('ignore')

prepared_model = prepare_fx(model_to_quantize, qconfig_dict)

prepare_fx folds BatchNorm modules into the preceding Conv2d modules, and inserts observers at the appropriate places in the model.

print(prepared_model.graph)
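To confirm that observers were actually inserted, we can also list the observer submodules that prepare_fx added (a quick sketch; activation_post_process is the naming FX uses for inserted activation observers in this version of PyTorch):

observers = [name for name, _ in prepared_model.named_modules()
             if "activation_post_process" in name]
print(f"{len(observers)} observers inserted, e.g. {observers[:3]}")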

Calibration#

After the observers are inserted in the model, we run the calibration function. The purpose of calibration is to run some representative workloads (for example, samples from the training dataset) through the model so that the observers can record the statistics of the tensors; this information is later used to compute the quantization parameters.

import torch

def calibrate(model, data_loader, samples=500):
    model.eval()
    with torch.no_grad():
        k = 0
        for image, _ in data_loader:
            if k > samples:
                break
            model(image)
            k += len(image)

calibrate(prepared_model, trainset)  # run calibration on sample data
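After calibration, each observer can report the quantization parameters it derived from the statistics it recorded. A small sketch for inspecting the first activation observer (calculate_qparams is a method on the standard observer classes):

for name, module in prepared_model.named_modules():
    if "activation_post_process" in name and hasattr(module, "calculate_qparams"):
        scale, zero_point = module.calculate_qparams()
        print(name, scale, zero_point)
        break  # just show the first observer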

Convert the Model to a Quantized Model#

convert_fx takes a calibrated model and produces a quantized model.

from torch.quantization.quantize_fx import convert_fx

quantized_model = convert_fx(prepared_model)
print(quantized_model)

Evaluation#

Now we can print the size and accuracy of the quantized model.

from torch_book.contrib.helper import evaluate, print_size_of_model
from torch import nn


criterion = nn.CrossEntropyLoss()

print("Size of model before quantization")
print_size_of_model(float_model)
print("Size of model after quantization")
print_size_of_model(quantized_model)
top1, top5 = evaluate(quantized_model, criterion, valset)
print(f"[before serilaization] Evaluation accuracy on test dataset: {top1.avg:2.2f}, {top5.avg:2.2f}")
fx_graph_mode_model_file_path = saved_model_dir + f"{model_name}_fx_graph_mode_quantized.pth"

# this does not run due to some errors loading the convrelu module:
# ModuleAttributeError: 'ConvReLU2d' object has no attribute '_modules'
# save the whole model directly
# torch.save(quantized_model, fx_graph_mode_model_file_path)
# loaded_quantized_model = torch.load(fx_graph_mode_model_file_path)

# save with state_dict
# torch.save(quantized_model.state_dict(), fx_graph_mode_model_file_path)
# import copy
# model_to_quantize = copy.deepcopy(float_model)
# prepared_model = prepare_fx(model_to_quantize, {"": qconfig})
# loaded_quantized_model = convert_fx(prepared_model)
# loaded_quantized_model.load_state_dict(torch.load(fx_graph_mode_model_file_path))

# save with script
torch.jit.save(torch.jit.script(quantized_model), fx_graph_mode_model_file_path)
loaded_quantized_model = torch.jit.load(fx_graph_mode_model_file_path)

top1, top5 = evaluate(loaded_quantized_model, criterion, valset)
print(f"[after serialization/deserialization] Evaluation accuracy on test dataset: {top1.avg:2.2f}, {top5.avg:2.2f}")

If you want better accuracy or performance, try changing the qconfig_dict.
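For instance (a sketch, not part of the original tutorial), you could switch to the qnnpack backend defaults when targeting ARM CPUs, then repeat the prepare/calibrate/convert steps:

torch.backends.quantized.engine = "qnnpack"  # select the matching kernel backend
alt_qconfig_dict = {"": get_default_qconfig("qnnpack")}

model_to_quantize = copy.deepcopy(float_model)
prepared = prepare_fx(model_to_quantize, alt_qconfig_dict)
calibrate(prepared, trainset)
quantized_qnnpack = convert_fx(prepared)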

Debugging the Quantized Model#

We can also print the weights of the quantized and non-quantized conv to see the difference. First call fuse explicitly to fuse the conv and bn in the model; note that fuse_fx only works in eval mode.

from torch.quantization.quantize_fx import fuse_fx

fused = fuse_fx(float_model)

conv1_weight_after_fuse = fused.conv1[0].weight[0]
conv1_weight_after_quant = quantized_model.conv1.weight().dequantize()[0]

print(torch.max(abs(conv1_weight_after_fuse - conv1_weight_after_quant)))

Comparison with the Baseline Float Model and Eager Mode Quantization#

scripted_float_model_file = "resnet18_scripted.pth"

print("Size of baseline model")
print_size_of_model(float_model)

top1, top5 = evaluate(float_model, criterion, valset)
print("Baseline Float Model Evaluation accuracy: %2.2f, %2.2f"%(top1.avg, top5.avg))
torch.jit.save(torch.jit.script(float_model), saved_model_dir + scripted_float_model_file)

In this section we compare the model quantized with FX graph mode quantization against the model quantized in eager mode. FX graph mode and eager mode produce very similar quantized models, so we expect the accuracy and speedup to be similar as well.

print("Size of Fx graph mode quantized model")
print_size_of_model(quantized_model)
top1, top5 = evaluate(quantized_model, criterion, valset)
print("FX graph mode quantized model Evaluation accuracy on test dataset: %2.2f, %2.2f"%(top1.avg, top5.avg))

from torchvision.models.quantization.resnet import resnet18
eager_quantized_model = resnet18(pretrained=True, quantize=True).eval()
print("Size of eager mode quantized model")
eager_quantized_model = torch.jit.script(eager_quantized_model)
print_size_of_model(eager_quantized_model)
top1, top5 = evaluate(eager_quantized_model, criterion, valset)
print("eager mode quantized model Evaluation accuracy on test dataset: %2.2f, %2.2f"%(top1.avg, top5.avg))
eager_mode_model_file = "resnet18_eager_mode_quantized.pth"
torch.jit.save(eager_quantized_model, saved_model_dir + eager_mode_model_file)

We can see that the model size and accuracy of the FX graph mode and eager mode quantized models are quite similar.

Running the models in AIBench (single threaded) gives the following results:

Scripted Float Model:
Self CPU time total: 192.48ms

Scripted Eager Mode Quantized Model:
Self CPU time total: 50.76ms

Scripted FX Graph Mode Quantized Model:
Self CPU time total: 50.63ms

As we can see, for resnet18 both the FX graph mode and eager mode quantized models achieve a similar speedup over the floating point model, roughly 2-4x faster. The actual speedup over the float model may vary with the model, device, build, input batch size, threading, and so on.
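To get a rough local latency comparison (a sketch; absolute numbers will differ from the AIBench figures above):

import time

torch.set_num_threads(1)  # match the single-threaded setting above

def benchmark(model, runs=50):
    x = torch.rand(1, 3, 224, 224)
    with torch.no_grad():
        for _ in range(5):   # warmup
            model(x)
        start = time.time()
        for _ in range(runs):
            model(x)
    return (time.time() - start) / runs * 1000  # ms per inference

scripted_float = torch.jit.load(saved_model_dir + scripted_float_model_file)
print(f"float: {benchmark(scripted_float):.2f} ms")
print(f"fx quantized: {benchmark(loaded_quantized_model):.2f} ms")
print(f"eager quantized: {benchmark(eager_quantized_model):.2f} ms")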