PyTorch 2 Export 训练后量化

参考:pt2e-PTQ

准备模型和数据集:

import os

import torch
import torch.nn as nn
from torchvision.models.resnet import resnet18, ResNet18_Weights
from imagenet import ImageNet

train_batch_size = 30
eval_batch_size = 50
# ImageNet (ILSVRC) root directory. Set the IMAGENET_ROOT environment variable to
# point elsewhere; the default keeps the original machine-specific location.
data_path = os.environ.get(
    "IMAGENET_ROOT", "/media/pc/data/lxw/home/data/datasets/ILSVRC"
)
# Training / evaluation loaders from the local ``imagenet`` helper module.
# NOTE(review): batches are presumably (images, targets) pairs -- the calibration
# loop unpacks them that way; confirm against imagenet.ImageNet.
dataset = ImageNet(data_path)
data_loader = dataset.train_loader(train_batch_size)
data_loader_test = dataset.test_loader(eval_batch_size)
# torch.export takes a *tuple* of positional args: the trailing comma makes this a
# 1-tuple (without it the parentheses are just grouping and yield a bare tensor).
example_inputs = (next(iter(data_loader))[0],)
criterion = nn.CrossEntropyLoss()
# Pretrained float ResNet-18, kept on CPU for quantization.
float_model = resnet18(weights=ResNet18_Weights.DEFAULT)
float_model = float_model.to("cpu")

进行训练后量化(Post Training Quantization,简称 PTQ)时,需要先将模型设置为评估模式。

# PTQ operates on a model in evaluation mode. Module.eval() switches the module in
# place and returns the same object, so both names refer to one module.
float_model.eval()
model_to_quantize = float_model

使用 torch.export.export() 导出模型

# Example input used for tracing: one random tensor of shape (2, 3, 224, 224),
# wrapped in a tuple because torch.export takes a tuple of positional args.
example_inputs = (torch.rand(2, 3, 224, 224),)

# Choose ONE export mode. Exporting both a static and a dynamic graph and then
# discarding the first only wastes a full trace of the model.
#   True  -> dim 0 (batch) of the first input is dynamic
#   False -> all shapes are specialized to example_inputs
use_dynamic_batch = True

if use_dynamic_batch:
    # PyTorch 2.6+: mark dim 0 of the first input as dynamic, leave others static.
    batch_dim = torch.export.Dim("dim")
    dynamic_shapes = ({0: batch_dim},) + (None,) * (len(example_inputs) - 1)
else:
    dynamic_shapes = None

# PyTorch 2.6+: export the model, capture its graph and take the GraphModule.
exported_model = torch.export.export(
    model_to_quantize, example_inputs, dynamic_shapes=dynamic_shapes
).module()

# PyTorch 2.5 and earlier (API differs):
# from torch._export import capture_pre_autograd_graph
# exported_model = capture_pre_autograd_graph(model_to_quantize, example_inputs)
# ...or with a dynamic batch dimension (the dynamic-dim API may also differ):
# from torch._export import dynamic_dim
# exported_model = capture_pre_autograd_graph(model_to_quantize, example_inputs, constraints=[dynamic_dim(example_inputs[0], 0)])

导入后端特定的量化器并配置模型的量化方式

以下代码片段创建 XNNPACKQuantizer,并为其设置全局的对称量化配置:

from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)

# Backend-specific quantizer for XNNPACK with one global symmetric config.
quantizer = XNNPACKQuantizer()
symmetric_config = get_symmetric_quantization_config()
# set_global returns the quantizer itself (its repr is the cell output shown below).
quantizer.set_global(symmetric_config)
<executorch.backends.xnnpack.quantizer.xnnpack_quantizer.XNNPACKQuantizer at 0x7fb9a82f80b0>

Quantizer 是后端特定的,每个 Quantizer 都提供各自的方式供用户配置模型。例如,下面展示了 XNNPACKQuantizer 支持的几种配置 API:

# Configure the quantizer at several granularities by chaining setters.
# ``qconfig_opt`` is a placeholder for a quantization config object (for example
# the result of get_symmetric_quantization_config()); it is not defined here.
# The whole chain is wrapped in parentheses: without them the indented
# ``.method(...)`` continuation lines are a syntax (indentation) error.
# NOTE(review): confirm the installed XNNPACKQuantizer exposes ``set_object_type``;
# some versions name the per-type setters differently (e.g. set_module_type).
(
    # Global config, applied wherever no more specific config matches.
    quantizer.set_global(qconfig_opt)
    # Config for every module of type Conv2d (a whole module type).
    .set_object_type(torch.nn.Conv2d, qconfig_opt)
    # Config for a PyTorch functional op.
    .set_object_type(torch.nn.functional.linear, qconfig_opt)
    # Config for one specific submodule, addressed by its qualified name "foo.bar".
    .set_module_name("foo.bar", qconfig_opt)
)

参见

了解如何编写新的 Quantizer

准备模型进行训练后量化

prepare_pt2e 会将 BatchNorm 算子折叠到其前面的 Conv2d 算子中,并在模型的适当位置插入观测者(observer)。

# Quiet noisy warnings while keeping the ones from the PT2E quantization code.
# filterwarnings() inserts each filter at the FRONT of the list, so filters
# registered later take precedence: the torchao.quantization.pt2e "default"
# entry wins over the blanket DeprecationWarning "ignore" for that module.
import warnings

for filter_kwargs in (
    # Hide every DeprecationWarning, from any module.
    dict(action='ignore', category=DeprecationWarning, module=r'.*'),
    # Hide all warnings emitted from torch.fx.graph.
    dict(action='ignore', module=r'torch.fx.graph'),
    # Show warnings from the PT2E quantization package (default behavior).
    dict(action='default', module=r'torchao.quantization.pt2e'),
):
    warnings.filterwarnings(**filter_kwargs)
from torchao.quantization.pt2e.quantize_pt2e import (
    convert_pt2e,
    prepare_pt2e,
)

# Fold BatchNorm into the preceding convolutions and insert observers
# (the activation_post_process_* modules), then print the resulting graph.
prepared_model = prepare_pt2e(exported_model, quantizer)
prepared_graph = prepared_model.graph
print(prepared_graph)

Hide code cell output

graph():
    %conv1_weight : [num_users=1] = get_attr[target=conv1.weight]
    %activation_post_process_1 : [num_users=1] = call_module[target=activation_post_process_1](args = (%conv1_weight,), kwargs = {})
    %layer1_0_conv1_weight : [num_users=1] = get_attr[target=layer1.0.conv1.weight]
    %activation_post_process_4 : [num_users=1] = call_module[target=activation_post_process_4](args = (%layer1_0_conv1_weight,), kwargs = {})
    %layer1_0_conv2_weight : [num_users=1] = get_attr[target=layer1.0.conv2.weight]
    %activation_post_process_6 : [num_users=1] = call_module[target=activation_post_process_6](args = (%layer1_0_conv2_weight,), kwargs = {})
    %layer1_1_conv1_weight : [num_users=1] = get_attr[target=layer1.1.conv1.weight]
    %activation_post_process_9 : [num_users=1] = call_module[target=activation_post_process_9](args = (%layer1_1_conv1_weight,), kwargs = {})
    %layer1_1_conv2_weight : [num_users=1] = get_attr[target=layer1.1.conv2.weight]
    %activation_post_process_11 : [num_users=1] = call_module[target=activation_post_process_11](args = (%layer1_1_conv2_weight,), kwargs = {})
    %layer2_0_conv1_weight : [num_users=1] = get_attr[target=layer2.0.conv1.weight]
    %activation_post_process_14 : [num_users=1] = call_module[target=activation_post_process_14](args = (%layer2_0_conv1_weight,), kwargs = {})
    %layer2_0_conv2_weight : [num_users=1] = get_attr[target=layer2.0.conv2.weight]
    %activation_post_process_16 : [num_users=1] = call_module[target=activation_post_process_16](args = (%layer2_0_conv2_weight,), kwargs = {})
    %layer2_0_downsample_0_weight : [num_users=1] = get_attr[target=layer2.0.downsample.0.weight]
    %activation_post_process_18 : [num_users=1] = call_module[target=activation_post_process_18](args = (%layer2_0_downsample_0_weight,), kwargs = {})
    %layer2_1_conv1_weight : [num_users=1] = get_attr[target=layer2.1.conv1.weight]
    %activation_post_process_21 : [num_users=1] = call_module[target=activation_post_process_21](args = (%layer2_1_conv1_weight,), kwargs = {})
    %layer2_1_conv2_weight : [num_users=1] = get_attr[target=layer2.1.conv2.weight]
    %activation_post_process_23 : [num_users=1] = call_module[target=activation_post_process_23](args = (%layer2_1_conv2_weight,), kwargs = {})
    %layer3_0_conv1_weight : [num_users=1] = get_attr[target=layer3.0.conv1.weight]
    %activation_post_process_26 : [num_users=1] = call_module[target=activation_post_process_26](args = (%layer3_0_conv1_weight,), kwargs = {})
    %layer3_0_conv2_weight : [num_users=1] = get_attr[target=layer3.0.conv2.weight]
    %activation_post_process_28 : [num_users=1] = call_module[target=activation_post_process_28](args = (%layer3_0_conv2_weight,), kwargs = {})
    %layer3_0_downsample_0_weight : [num_users=1] = get_attr[target=layer3.0.downsample.0.weight]
    %activation_post_process_30 : [num_users=1] = call_module[target=activation_post_process_30](args = (%layer3_0_downsample_0_weight,), kwargs = {})
    %layer3_1_conv1_weight : [num_users=1] = get_attr[target=layer3.1.conv1.weight]
    %activation_post_process_33 : [num_users=1] = call_module[target=activation_post_process_33](args = (%layer3_1_conv1_weight,), kwargs = {})
    %layer3_1_conv2_weight : [num_users=1] = get_attr[target=layer3.1.conv2.weight]
    %activation_post_process_35 : [num_users=1] = call_module[target=activation_post_process_35](args = (%layer3_1_conv2_weight,), kwargs = {})
    %layer4_0_conv1_weight : [num_users=1] = get_attr[target=layer4.0.conv1.weight]
    %activation_post_process_38 : [num_users=1] = call_module[target=activation_post_process_38](args = (%layer4_0_conv1_weight,), kwargs = {})
    %layer4_0_conv2_weight : [num_users=1] = get_attr[target=layer4.0.conv2.weight]
    %activation_post_process_40 : [num_users=1] = call_module[target=activation_post_process_40](args = (%layer4_0_conv2_weight,), kwargs = {})
    %layer4_0_downsample_0_weight : [num_users=1] = get_attr[target=layer4.0.downsample.0.weight]
    %activation_post_process_42 : [num_users=1] = call_module[target=activation_post_process_42](args = (%layer4_0_downsample_0_weight,), kwargs = {})
    %layer4_1_conv1_weight : [num_users=1] = get_attr[target=layer4.1.conv1.weight]
    %activation_post_process_45 : [num_users=1] = call_module[target=activation_post_process_45](args = (%layer4_1_conv1_weight,), kwargs = {})
    %layer4_1_conv2_weight : [num_users=1] = get_attr[target=layer4.1.conv2.weight]
    %activation_post_process_47 : [num_users=1] = call_module[target=activation_post_process_47](args = (%layer4_1_conv2_weight,), kwargs = {})
    %fc_weight : [num_users=1] = get_attr[target=fc.weight]
    %activation_post_process_52 : [num_users=1] = call_module[target=activation_post_process_52](args = (%fc_weight,), kwargs = {})
    %fc_bias : [num_users=1] = get_attr[target=fc.bias]
    %x : [num_users=1] = placeholder[target=x]
    %activation_post_process_0 : [num_users=1] = call_module[target=activation_post_process_0](args = (%x,), kwargs = {})
    %conv1_weight_bias : [num_users=1] = get_attr[target=conv1.weight_bias]
    %conv2d : [num_users=1] = call_function[target=torch.ops.aten.conv2d.default](args = (%activation_post_process_0, %activation_post_process_1, %conv1_weight_bias, [2, 2], [3, 3]), kwargs = {})
    %relu_ : [num_users=1] = call_function[target=torch.ops.aten.relu_.default](args = (%conv2d,), kwargs = {})
    %activation_post_process_2 : [num_users=1] = call_module[target=activation_post_process_2](args = (%relu_,), kwargs = {})
    %max_pool2d : [num_users=1] = call_function[target=torch.ops.aten.max_pool2d.default](args = (%activation_post_process_2, [3, 3], [2, 2], [1, 1]), kwargs = {})
    %activation_post_process_3 : [num_users=2] = call_module[target=activation_post_process_3](args = (%max_pool2d,), kwargs = {})
    %layer1_0_conv1_weight_bias : [num_users=1] = get_attr[target=layer1.0.conv1.weight_bias]
    %conv2d_1 : [num_users=1] = call_function[target=torch.ops.aten.conv2d.default](args = (%activation_post_process_3, %activation_post_process_4, %layer1_0_conv1_weight_bias, [1, 1], [1, 1]), kwargs = {})
    %relu__1 : [num_users=1] = call_function[target=torch.ops.aten.relu_.default](args = (%conv2d_1,), kwargs = {})
    %activation_post_process_5 : [num_users=1] = call_module[target=activation_post_process_5](args = (%relu__1,), kwargs = {})
    %layer1_0_conv2_weight_bias : [num_users=1] = get_attr[target=layer1.0.conv2.weight_bias]
    %conv2d_2 : [num_users=1] = call_function[target=torch.ops.aten.conv2d.default](args = (%activation_post_process_5, %activation_post_process_6, %layer1_0_conv2_weight_bias, [1, 1], [1, 1]), kwargs = {})
    %activation_post_process_7 : [num_users=1] = call_module[target=activation_post_process_7](args = (%conv2d_2,), kwargs = {})
    %add_ : [num_users=1] = call_function[target=torch.ops.aten.add_.Tensor](args = (%activation_post_process_7, %activation_post_process_3), kwargs = {})
    %relu__2 : [num_users=1] = call_function[target=torch.ops.aten.relu_.default](args = (%add_,), kwargs = {})
    %activation_post_process_8 : [num_users=2] = call_module[target=activation_post_process_8](args = (%relu__2,), kwargs = {})
    %layer1_1_conv1_weight_bias : [num_users=1] = get_attr[target=layer1.1.conv1.weight_bias]
    %conv2d_3 : [num_users=1] = call_function[target=torch.ops.aten.conv2d.default](args = (%activation_post_process_8, %activation_post_process_9, %layer1_1_conv1_weight_bias, [1, 1], [1, 1]), kwargs = {})
    %relu__3 : [num_users=1] = call_function[target=torch.ops.aten.relu_.default](args = (%conv2d_3,), kwargs = {})
    %activation_post_process_10 : [num_users=1] = call_module[target=activation_post_process_10](args = (%relu__3,), kwargs = {})
    %layer1_1_conv2_weight_bias : [num_users=1] = get_attr[target=layer1.1.conv2.weight_bias]
    %conv2d_4 : [num_users=1] = call_function[target=torch.ops.aten.conv2d.default](args = (%activation_post_process_10, %activation_post_process_11, %layer1_1_conv2_weight_bias, [1, 1], [1, 1]), kwargs = {})
    %activation_post_process_12 : [num_users=1] = call_module[target=activation_post_process_12](args = (%conv2d_4,), kwargs = {})
    %add__1 : [num_users=1] = call_function[target=torch.ops.aten.add_.Tensor](args = (%activation_post_process_12, %activation_post_process_8), kwargs = {})
    %relu__4 : [num_users=1] = call_function[target=torch.ops.aten.relu_.default](args = (%add__1,), kwargs = {})
    %activation_post_process_13 : [num_users=2] = call_module[target=activation_post_process_13](args = (%relu__4,), kwargs = {})
    %layer2_0_conv1_weight_bias : [num_users=1] = get_attr[target=layer2.0.conv1.weight_bias]
    %conv2d_5 : [num_users=1] = call_function[target=torch.ops.aten.conv2d.default](args = (%activation_post_process_13, %activation_post_process_14, %layer2_0_conv1_weight_bias, [2, 2], [1, 1]), kwargs = {})
    %relu__5 : [num_users=1] = call_function[target=torch.ops.aten.relu_.default](args = (%conv2d_5,), kwargs = {})
    %activation_post_process_15 : [num_users=1] = call_module[target=activation_post_process_15](args = (%relu__5,), kwargs = {})
    %layer2_0_conv2_weight_bias : [num_users=1] = get_attr[target=layer2.0.conv2.weight_bias]
    %conv2d_6 : [num_users=1] = call_function[target=torch.ops.aten.conv2d.default](args = (%activation_post_process_15, %activation_post_process_16, %layer2_0_conv2_weight_bias, [1, 1], [1, 1]), kwargs = {})
    %activation_post_process_17 : [num_users=1] = call_module[target=activation_post_process_17](args = (%conv2d_6,), kwargs = {})
    %layer2_0_downsample_0_weight_bias : [num_users=1] = get_attr[target=layer2.0.downsample.0.weight_bias]
    %conv2d_7 : [num_users=1] = call_function[target=torch.ops.aten.conv2d.default](args = (%activation_post_process_13, %activation_post_process_18, %layer2_0_downsample_0_weight_bias, [2, 2]), kwargs = {})
    %activation_post_process_19 : [num_users=1] = call_module[target=activation_post_process_19](args = (%conv2d_7,), kwargs = {})
    %add__2 : [num_users=1] = call_function[target=torch.ops.aten.add_.Tensor](args = (%activation_post_process_17, %activation_post_process_19), kwargs = {})
    %relu__6 : [num_users=1] = call_function[target=torch.ops.aten.relu_.default](args = (%add__2,), kwargs = {})
    %activation_post_process_20 : [num_users=2] = call_module[target=activation_post_process_20](args = (%relu__6,), kwargs = {})
    %layer2_1_conv1_weight_bias : [num_users=1] = get_attr[target=layer2.1.conv1.weight_bias]
    %conv2d_8 : [num_users=1] = call_function[target=torch.ops.aten.conv2d.default](args = (%activation_post_process_20, %activation_post_process_21, %layer2_1_conv1_weight_bias, [1, 1], [1, 1]), kwargs = {})
    %relu__7 : [num_users=1] = call_function[target=torch.ops.aten.relu_.default](args = (%conv2d_8,), kwargs = {})
    %activation_post_process_22 : [num_users=1] = call_module[target=activation_post_process_22](args = (%relu__7,), kwargs = {})
    %layer2_1_conv2_weight_bias : [num_users=1] = get_attr[target=layer2.1.conv2.weight_bias]
    %conv2d_9 : [num_users=1] = call_function[target=torch.ops.aten.conv2d.default](args = (%activation_post_process_22, %activation_post_process_23, %layer2_1_conv2_weight_bias, [1, 1], [1, 1]), kwargs = {})
    %activation_post_process_24 : [num_users=1] = call_module[target=activation_post_process_24](args = (%conv2d_9,), kwargs = {})
    %add__3 : [num_users=1] = call_function[target=torch.ops.aten.add_.Tensor](args = (%activation_post_process_24, %activation_post_process_20), kwargs = {})
    %relu__8 : [num_users=1] = call_function[target=torch.ops.aten.relu_.default](args = (%add__3,), kwargs = {})
    %activation_post_process_25 : [num_users=2] = call_module[target=activation_post_process_25](args = (%relu__8,), kwargs = {})
    %layer3_0_conv1_weight_bias : [num_users=1] = get_attr[target=layer3.0.conv1.weight_bias]
    %conv2d_10 : [num_users=1] = call_function[target=torch.ops.aten.conv2d.default](args = (%activation_post_process_25, %activation_post_process_26, %layer3_0_conv1_weight_bias, [2, 2], [1, 1]), kwargs = {})
    %relu__9 : [num_users=1] = call_function[target=torch.ops.aten.relu_.default](args = (%conv2d_10,), kwargs = {})
    %activation_post_process_27 : [num_users=1] = call_module[target=activation_post_process_27](args = (%relu__9,), kwargs = {})
    %layer3_0_conv2_weight_bias : [num_users=1] = get_attr[target=layer3.0.conv2.weight_bias]
    %conv2d_11 : [num_users=1] = call_function[target=torch.ops.aten.conv2d.default](args = (%activation_post_process_27, %activation_post_process_28, %layer3_0_conv2_weight_bias, [1, 1], [1, 1]), kwargs = {})
    %activation_post_process_29 : [num_users=1] = call_module[target=activation_post_process_29](args = (%conv2d_11,), kwargs = {})
    %layer3_0_downsample_0_weight_bias : [num_users=1] = get_attr[target=layer3.0.downsample.0.weight_bias]
    %conv2d_12 : [num_users=1] = call_function[target=torch.ops.aten.conv2d.default](args = (%activation_post_process_25, %activation_post_process_30, %layer3_0_downsample_0_weight_bias, [2, 2]), kwargs = {})
    %activation_post_process_31 : [num_users=1] = call_module[target=activation_post_process_31](args = (%conv2d_12,), kwargs = {})
    %add__4 : [num_users=1] = call_function[target=torch.ops.aten.add_.Tensor](args = (%activation_post_process_29, %activation_post_process_31), kwargs = {})
    %relu__10 : [num_users=1] = call_function[target=torch.ops.aten.relu_.default](args = (%add__4,), kwargs = {})
    %activation_post_process_32 : [num_users=2] = call_module[target=activation_post_process_32](args = (%relu__10,), kwargs = {})
    %layer3_1_conv1_weight_bias : [num_users=1] = get_attr[target=layer3.1.conv1.weight_bias]
    %conv2d_13 : [num_users=1] = call_function[target=torch.ops.aten.conv2d.default](args = (%activation_post_process_32, %activation_post_process_33, %layer3_1_conv1_weight_bias, [1, 1], [1, 1]), kwargs = {})
    %relu__11 : [num_users=1] = call_function[target=torch.ops.aten.relu_.default](args = (%conv2d_13,), kwargs = {})
    %activation_post_process_34 : [num_users=1] = call_module[target=activation_post_process_34](args = (%relu__11,), kwargs = {})
    %layer3_1_conv2_weight_bias : [num_users=1] = get_attr[target=layer3.1.conv2.weight_bias]
    %conv2d_14 : [num_users=1] = call_function[target=torch.ops.aten.conv2d.default](args = (%activation_post_process_34, %activation_post_process_35, %layer3_1_conv2_weight_bias, [1, 1], [1, 1]), kwargs = {})
    %activation_post_process_36 : [num_users=1] = call_module[target=activation_post_process_36](args = (%conv2d_14,), kwargs = {})
    %add__5 : [num_users=1] = call_function[target=torch.ops.aten.add_.Tensor](args = (%activation_post_process_36, %activation_post_process_32), kwargs = {})
    %relu__12 : [num_users=1] = call_function[target=torch.ops.aten.relu_.default](args = (%add__5,), kwargs = {})
    %activation_post_process_37 : [num_users=2] = call_module[target=activation_post_process_37](args = (%relu__12,), kwargs = {})
    %layer4_0_conv1_weight_bias : [num_users=1] = get_attr[target=layer4.0.conv1.weight_bias]
    %conv2d_15 : [num_users=1] = call_function[target=torch.ops.aten.conv2d.default](args = (%activation_post_process_37, %activation_post_process_38, %layer4_0_conv1_weight_bias, [2, 2], [1, 1]), kwargs = {})
    %relu__13 : [num_users=1] = call_function[target=torch.ops.aten.relu_.default](args = (%conv2d_15,), kwargs = {})
    %activation_post_process_39 : [num_users=1] = call_module[target=activation_post_process_39](args = (%relu__13,), kwargs = {})
    %layer4_0_conv2_weight_bias : [num_users=1] = get_attr[target=layer4.0.conv2.weight_bias]
    %conv2d_16 : [num_users=1] = call_function[target=torch.ops.aten.conv2d.default](args = (%activation_post_process_39, %activation_post_process_40, %layer4_0_conv2_weight_bias, [1, 1], [1, 1]), kwargs = {})
    %activation_post_process_41 : [num_users=1] = call_module[target=activation_post_process_41](args = (%conv2d_16,), kwargs = {})
    %layer4_0_downsample_0_weight_bias : [num_users=1] = get_attr[target=layer4.0.downsample.0.weight_bias]
    %conv2d_17 : [num_users=1] = call_function[target=torch.ops.aten.conv2d.default](args = (%activation_post_process_37, %activation_post_process_42, %layer4_0_downsample_0_weight_bias, [2, 2]), kwargs = {})
    %activation_post_process_43 : [num_users=1] = call_module[target=activation_post_process_43](args = (%conv2d_17,), kwargs = {})
    %add__6 : [num_users=1] = call_function[target=torch.ops.aten.add_.Tensor](args = (%activation_post_process_41, %activation_post_process_43), kwargs = {})
    %relu__14 : [num_users=1] = call_function[target=torch.ops.aten.relu_.default](args = (%add__6,), kwargs = {})
    %activation_post_process_44 : [num_users=2] = call_module[target=activation_post_process_44](args = (%relu__14,), kwargs = {})
    %layer4_1_conv1_weight_bias : [num_users=1] = get_attr[target=layer4.1.conv1.weight_bias]
    %conv2d_18 : [num_users=1] = call_function[target=torch.ops.aten.conv2d.default](args = (%activation_post_process_44, %activation_post_process_45, %layer4_1_conv1_weight_bias, [1, 1], [1, 1]), kwargs = {})
    %relu__15 : [num_users=1] = call_function[target=torch.ops.aten.relu_.default](args = (%conv2d_18,), kwargs = {})
    %activation_post_process_46 : [num_users=1] = call_module[target=activation_post_process_46](args = (%relu__15,), kwargs = {})
    %layer4_1_conv2_weight_bias : [num_users=1] = get_attr[target=layer4.1.conv2.weight_bias]
    %conv2d_19 : [num_users=1] = call_function[target=torch.ops.aten.conv2d.default](args = (%activation_post_process_46, %activation_post_process_47, %layer4_1_conv2_weight_bias, [1, 1], [1, 1]), kwargs = {})
    %activation_post_process_48 : [num_users=1] = call_module[target=activation_post_process_48](args = (%conv2d_19,), kwargs = {})
    %add__7 : [num_users=1] = call_function[target=torch.ops.aten.add_.Tensor](args = (%activation_post_process_48, %activation_post_process_44), kwargs = {})
    %relu__16 : [num_users=1] = call_function[target=torch.ops.aten.relu_.default](args = (%add__7,), kwargs = {})
    %activation_post_process_49 : [num_users=1] = call_module[target=activation_post_process_49](args = (%relu__16,), kwargs = {})
    %adaptive_avg_pool2d : [num_users=1] = call_function[target=torch.ops.aten.adaptive_avg_pool2d.default](args = (%activation_post_process_49, [1, 1]), kwargs = {})
    %activation_post_process_50 : [num_users=1] = call_module[target=activation_post_process_50](args = (%adaptive_avg_pool2d,), kwargs = {})
    %flatten : [num_users=1] = call_function[target=torch.ops.aten.flatten.using_ints](args = (%activation_post_process_50, 1), kwargs = {})
    %activation_post_process_51 : [num_users=1] = call_module[target=activation_post_process_51](args = (%flatten,), kwargs = {})
    %linear : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%activation_post_process_51, %activation_post_process_52, %fc_bias), kwargs = {})
    %activation_post_process_53 : [num_users=1] = call_module[target=activation_post_process_53](args = (%linear,), kwargs = {})
    return (activation_post_process_53,)

校准

在模型中插入观测者后运行校准函数。校准的目的是运行一些具有代表性的样本示例(例如训练数据集的样本),以便模型中的观测者能够观测张量的统计数据,稍后可以使用这些信息来计算量化参数。

def calibrate(model, data_loader, *, num_batches=None):
    """Run forward passes so the observers inserted by ``prepare_pt2e`` see data.

    Model outputs are discarded; only the observers' recorded statistics matter,
    and gradients are disabled because nothing is trained here.

    Args:
        model: Observer-instrumented model (as returned by ``prepare_pt2e``).
        data_loader: Iterable yielding ``(inputs, target)`` pairs; targets are
            ignored.
        num_batches: Optional keyword-only cap on how many batches to consume.
            ``None`` (the default) iterates the whole loader, as before; a small
            cap keeps calibration on a large dataset fast.

    Raises:
        ValueError: If ``num_batches`` is negative (raised by ``itertools.islice``).
    """
    from itertools import islice

    # No mode switch here: the model is expected to already be in inference mode
    # (the float model was put in eval() before export).
    with torch.no_grad():
        # islice stops before requesting batch ``num_batches``, so a capped run
        # never loads an extra batch from the loader.
        for images, _ in islice(data_loader, num_batches):
            model(images)
# Calibrate the observer-instrumented model on the evaluation loader
# (data_loader_test, batch size eval_batch_size).
# NOTE(review): the text above suggests representative samples such as training
# data, but this call uses the test loader -- confirm that is intended.
calibrate(prepared_model, data_loader_test)

将校准后的模型转换为量化模型

convert_pt2e 接收校准后的模型,并生成量化后的模型。

# Convert the calibrated model into the quantized model, then print it. The printed
# forward code shows int8 weights stored as frozen params that are dequantized
# (dequantize_per_tensor) at runtime.
quantized_model = convert_pt2e(prepared_model)
print(quantized_model)

Hide code cell output

GraphModule(
  (conv1): Module()
  (layer1): Module(
    (0): Module(
      (conv1): Module()
      (conv2): Module()
    )
    (1): Module(
      (conv1): Module()
      (conv2): Module()
    )
  )
  (layer2): Module(
    (0): Module(
      (conv1): Module()
      (conv2): Module()
      (downsample): Module(
        (0): Module()
      )
    )
    (1): Module(
      (conv1): Module()
      (conv2): Module()
    )
  )
  (layer3): Module(
    (0): Module(
      (conv1): Module()
      (conv2): Module()
      (downsample): Module(
        (0): Module()
      )
    )
    (1): Module(
      (conv1): Module()
      (conv2): Module()
    )
  )
  (layer4): Module(
    (0): Module(
      (conv1): Module()
      (conv2): Module()
      (downsample): Module(
        (0): Module()
      )
    )
    (1): Module(
      (conv1): Module()
      (conv2): Module()
    )
  )
  (fc): Module()
)



def forward(self, x):
    x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
    quantize_per_tensor_default = self._frozen_param0
    dequantize_per_tensor_default = torch.ops.quantized_decomposed.dequantize_per_tensor.default(quantize_per_tensor_default, 0.0031013814732432365, 0, -127, 127, torch.int8);  quantize_per_tensor_default = None
    quantize_per_tensor_default_1 = self._frozen_param1
    dequantize_per_tensor_default_1 = torch.ops.quantized_decomposed.dequantize_per_tensor.default(quantize_per_tensor_default_1, 0.0029488515574485064, 0, -127, 127, torch.int8);  quantize_per_tensor_default_1 = None
    quantize_per_tensor_default_2 = self._frozen_param2
    dequantize_per_tensor_default_2 = torch.ops.quantized_decomposed.dequantize_per_tensor.default(quantize_per_tensor_default_2, 0.006069142837077379, 0, -127, 127, torch.int8);  quantize_per_tensor_default_2 = None
    quantize_per_tensor_default_3 = self._frozen_param3
    dequantize_per_tensor_default_3 = torch.ops.quantized_decomposed.dequantize_per_tensor.default(quantize_per_tensor_default_3, 0.0022043888457119465, 0, -127, 127, torch.int8);  quantize_per_tensor_default_3 = None
    quantize_per_tensor_default_4 = self._frozen_param4
    dequantize_per_tensor_default_4 = torch.ops.quantized_decomposed.dequantize_per_tensor.default(quantize_per_tensor_default_4, 0.00824717152863741, 0, -127, 127, torch.int8);  quantize_per_tensor_default_4 = None
    quantize_per_tensor_default_5 = self._frozen_param5
    dequantize_per_tensor_default_5 = torch.ops.quantized_decomposed.dequantize_per_tensor.default(quantize_per_tensor_default_5, 0.0016747699119150639, 0, -127, 127, torch.int8);  quantize_per_tensor_default_5 = None
    quantize_per_tensor_default_6 = self._frozen_param6
    dequantize_per_tensor_default_6 = torch.ops.quantized_decomposed.dequantize_per_tensor.default(quantize_per_tensor_default_6, 0.005701970309019089, 0, -127, 127, torch.int8);  quantize_per_tensor_default_6 = None
    quantize_per_tensor_default_7 = self._frozen_param7
    dequantize_per_tensor_default_7 = torch.ops.quantized_decomposed.dequantize_per_tensor.default(quantize_per_tensor_default_7, 0.005450794007629156, 0, -127, 127, torch.int8);  quantize_per_tensor_default_7 = None
    quantize_per_tensor_default_8 = self._frozen_param8
    dequantize_per_tensor_default_8 = torch.ops.quantized_decomposed.dequantize_per_tensor.default(quantize_per_tensor_default_8, 0.0024492174852639437, 0, -127, 127, torch.int8);  quantize_per_tensor_default_8 = None
    quantize_per_tensor_default_9 = self._frozen_param9
    dequantize_per_tensor_default_9 = torch.ops.quantized_decomposed.dequantize_per_tensor.default(quantize_per_tensor_default_9, 0.006906129419803619, 0, -127, 127, torch.int8);  quantize_per_tensor_default_9 = None
    quantize_per_tensor_default_10 = self._frozen_param10
    dequantize_per_tensor_default_10 = torch.ops.quantized_decomposed.dequantize_per_tensor.default(quantize_per_tensor_default_10, 0.001856042305007577, 0, -127, 127, torch.int8);  quantize_per_tensor_default_10 = None
    quantize_per_tensor_default_11 = self._frozen_param11
    dequantize_per_tensor_default_11 = torch.ops.quantized_decomposed.dequantize_per_tensor.default(quantize_per_tensor_default_11, 0.004440308548510075, 0, -127, 127, torch.int8);  quantize_per_tensor_default_11 = None
    quantize_per_tensor_default_12 = self._frozen_param12
    dequantize_per_tensor_default_12 = torch.ops.quantized_decomposed.dequantize_per_tensor.default(quantize_per_tensor_default_12, 0.003213868010789156, 0, -127, 127, torch.int8);  quantize_per_tensor_default_12 = None
    quantize_per_tensor_default_13 = self._frozen_param13
    dequantize_per_tensor_default_13 = torch.ops.quantized_decomposed.dequantize_per_tensor.default(quantize_per_tensor_default_13, 0.002144748345017433, 0, -127, 127, torch.int8);  quantize_per_tensor_default_13 = None
    quantize_per_tensor_default_14 = self._frozen_param14
    dequantize_per_tensor_default_14 = torch.ops.quantized_decomposed.dequantize_per_tensor.default(quantize_per_tensor_default_14, 0.007638747803866863, 0, -127, 127, torch.int8);  quantize_per_tensor_default_14 = None
    quantize_per_tensor_default_15 = self._frozen_param15
    dequantize_per_tensor_default_15 = torch.ops.quantized_decomposed.dequantize_per_tensor.default(quantize_per_tensor_default_15, 0.002374982926994562, 0, -127, 127, torch.int8);  quantize_per_tensor_default_15 = None
    quantize_per_tensor_default_16 = self._frozen_param16
    dequantize_per_tensor_default_16 = torch.ops.quantized_decomposed.dequantize_per_tensor.default(quantize_per_tensor_default_16, 0.009006100706756115, 0, -127, 127, torch.int8);  quantize_per_tensor_default_16 = None
    quantize_per_tensor_default_17 = self._frozen_param17
    dequantize_per_tensor_default_17 = torch.ops.quantized_decomposed.dequantize_per_tensor.default(quantize_per_tensor_default_17, 0.007859906181693077, 0, -127, 127, torch.int8);  quantize_per_tensor_default_17 = None
    quantize_per_tensor_default_18 = self._frozen_param18
    dequantize_per_tensor_default_18 = torch.ops.quantized_decomposed.dequantize_per_tensor.default(quantize_per_tensor_default_18, 0.002309155184775591, 0, -127, 127, torch.int8);  quantize_per_tensor_default_18 = None
    quantize_per_tensor_default_19 = self._frozen_param19
    dequantize_per_tensor_default_19 = torch.ops.quantized_decomposed.dequantize_per_tensor.default(quantize_per_tensor_default_19, 0.028726834803819656, 0, -127, 127, torch.int8);  quantize_per_tensor_default_19 = None
    quantize_per_tensor_default_20 = self._frozen_param20
    dequantize_per_tensor_default_20 = torch.ops.quantized_decomposed.dequantize_per_tensor.default(quantize_per_tensor_default_20, 0.005631787236779928, 0, -127, 127, torch.int8);  quantize_per_tensor_default_20 = None
    fc_bias = self.fc.bias
    quantize_per_tensor_default_21 = torch.ops.quantized_decomposed.quantize_per_tensor.default(x, 0.018649335950613022, -14, -128, 127, torch.int8);  x = None
    dequantize_per_tensor_default_21 = torch.ops.quantized_decomposed.dequantize_per_tensor.default(quantize_per_tensor_default_21, 0.018649335950613022, -14, -128, 127, torch.int8);  quantize_per_tensor_default_21 = None
    conv1_weight_bias = self.conv1.weight_bias
    conv2d = torch.ops.aten.conv2d.default(dequantize_per_tensor_default_21, dequantize_per_tensor_default, conv1_weight_bias, [2, 2], [3, 3]);  dequantize_per_tensor_default_21 = dequantize_per_tensor_default = conv1_weight_bias = None
    relu_ = torch.ops.aten.relu_.default(conv2d);  conv2d = None
    quantize_per_tensor_default_22 = torch.ops.quantized_decomposed.quantize_per_tensor.default(relu_, 0.014440659433603287, -128, -128, 127, torch.int8);  relu_ = None
    dequantize_per_tensor_default_22 = torch.ops.quantized_decomposed.dequantize_per_tensor.default(quantize_per_tensor_default_22, 0.014440659433603287, -128, -128, 127, torch.int8);  quantize_per_tensor_default_22 = None
    max_pool2d = torch.ops.aten.max_pool2d.default(dequantize_per_tensor_default_22, [3, 3], [2, 2], [1, 1]);  dequantize_per_tensor_default_22 = None
    quantize_per_tensor_default_23 = torch.ops.quantized_decomposed.quantize_per_tensor.default(max_pool2d, 0.014440659433603287, -128, -128, 127, torch.int8);  max_pool2d = None
    dequantize_per_tensor_default_55 = torch.ops.quantized_decomposed.dequantize_per_tensor.default(quantize_per_tensor_default_23, 0.014440659433603287, -128, -128, 127, torch.int8)
    dequantize_per_tensor_default_54 = torch.ops.quantized_decomposed.dequantize_per_tensor.default(quantize_per_tensor_default_23, 0.014440659433603287, -128, -128, 127, torch.int8);  quantize_per_tensor_default_23 = None
    layer1_0_conv1_weight_bias = getattr(self.layer1, "0").conv1.weight_bias
    conv2d_1 = torch.ops.aten.conv2d.default(dequantize_per_tensor_default_54, dequantize_per_tensor_default_1, layer1_0_conv1_weight_bias, [1, 1], [1, 1]);  dequantize_per_tensor_default_54 = dequantize_per_tensor_default_1 = layer1_0_conv1_weight_bias = None
    relu__1 = torch.ops.aten.relu_.default(conv2d_1);  conv2d_1 = None
    quantize_per_tensor_default_24 = torch.ops.quantized_decomposed.quantize_per_tensor.default(relu__1, 0.008789400570094585, -128, -128, 127, torch.int8);  relu__1 = None
    dequantize_per_tensor_default_24 = torch.ops.quantized_decomposed.dequantize_per_tensor.default(quantize_per_tensor_default_24, 0.008789400570094585, -128, -128, 127, torch.int8);  quantize_per_tensor_default_24 = None
    layer1_0_conv2_weight_bias = getattr(self.layer1, "0").conv2.weight_bias
    conv2d_2 = torch.ops.aten.conv2d.default(dequantize_per_tensor_default_24, dequantize_per_tensor_default_2, layer1_0_conv2_weight_bias, [1, 1], [1, 1]);  dequantize_per_tensor_default_24 = dequantize_per_tensor_default_2 = layer1_0_conv2_weight_bias = None
    quantize_per_tensor_default_25 = torch.ops.quantized_decomposed.quantize_per_tensor.default(conv2d_2, 0.023757578805088997, 21, -128, 127, torch.int8);  conv2d_2 = None
    dequantize_per_tensor_default_25 = torch.ops.quantized_decomposed.dequantize_per_tensor.default(quantize_per_tensor_default_25, 0.023757578805088997, 21, -128, 127, torch.int8);  quantize_per_tensor_default_25 = None
    add_ = torch.ops.aten.add_.Tensor(dequantize_per_tensor_default_25, dequantize_per_tensor_default_55);  dequantize_per_tensor_default_25 = dequantize_per_tensor_default_55 = None
    relu__2 = torch.ops.aten.relu_.default(add_);  add_ = None
    quantize_per_tensor_default_26 = torch.ops.quantized_decomposed.quantize_per_tensor.default(relu__2, 0.015983207151293755, -128, -128, 127, torch.int8);  relu__2 = None
    dequantize_per_tensor_default_57 = torch.ops.quantized_decomposed.dequantize_per_tensor.default(quantize_per_tensor_default_26, 0.015983207151293755, -128, -128, 127, torch.int8)
    dequantize_per_tensor_default_56 = torch.ops.quantized_decomposed.dequantize_per_tensor.default(quantize_per_tensor_default_26, 0.015983207151293755, -128, -128, 127, torch.int8);  quantize_per_tensor_default_26 = None
    layer1_1_conv1_weight_bias = getattr(self.layer1, "1").conv1.weight_bias
    conv2d_3 = torch.ops.aten.conv2d.default(dequantize_per_tensor_default_56, dequantize_per_tensor_default_3, layer1_1_conv1_weight_bias, [1, 1], [1, 1]);  dequantize_per_tensor_default_56 = dequantize_per_tensor_default_3 = layer1_1_conv1_weight_bias = None
    relu__3 = torch.ops.aten.relu_.default(conv2d_3);  conv2d_3 = None
    quantize_per_tensor_default_27 = torch.ops.quantized_decomposed.quantize_per_tensor.default(relu__3, 0.0082000233232975, -128, -128, 127, torch.int8);  relu__3 = None
    dequantize_per_tensor_default_27 = torch.ops.quantized_decomposed.dequantize_per_tensor.default(quantize_per_tensor_default_27, 0.0082000233232975, -128, -128, 127, torch.int8);  quantize_per_tensor_default_27 = None
    layer1_1_conv2_weight_bias = getattr(self.layer1, "1").conv2.weight_bias
    conv2d_4 = torch.ops.aten.conv2d.default(dequantize_per_tensor_default_27, dequantize_per_tensor_default_4, layer1_1_conv2_weight_bias, [1, 1], [1, 1]);  dequantize_per_tensor_default_27 = dequantize_per_tensor_default_4 = layer1_1_conv2_weight_bias = None
    quantize_per_tensor_default_28 = torch.ops.quantized_decomposed.quantize_per_tensor.default(conv2d_4, 0.03155418112874031, 29, -128, 127, torch.int8);  conv2d_4 = None
    dequantize_per_tensor_default_28 = torch.ops.quantized_decomposed.dequantize_per_tensor.default(quantize_per_tensor_default_28, 0.03155418112874031, 29, -128, 127, torch.int8);  quantize_per_tensor_default_28 = None
    add__1 = torch.ops.aten.add_.Tensor(dequantize_per_tensor_default_28, dequantize_per_tensor_default_57);  dequantize_per_tensor_default_28 = dequantize_per_tensor_default_57 = None
    relu__4 = torch.ops.aten.relu_.default(add__1);  add__1 = None
    quantize_per_tensor_default_29 = torch.ops.quantized_decomposed.quantize_per_tensor.default(relu__4, 0.018283076584339142, -128, -128, 127, torch.int8);  relu__4 = None
    dequantize_per_tensor_default_59 = torch.ops.quantized_decomposed.dequantize_per_tensor.default(quantize_per_tensor_default_29, 0.018283076584339142, -128, -128, 127, torch.int8)
    dequantize_per_tensor_default_58 = torch.ops.quantized_decomposed.dequantize_per_tensor.default(quantize_per_tensor_default_29, 0.018283076584339142, -128, -128, 127, torch.int8);  quantize_per_tensor_default_29 = None
    layer2_0_conv1_weight_bias = getattr(self.layer2, "0").conv1.weight_bias
    conv2d_5 = torch.ops.aten.conv2d.default(dequantize_per_tensor_default_58, dequantize_per_tensor_default_5, layer2_0_conv1_weight_bias, [2, 2], [1, 1]);  dequantize_per_tensor_default_58 = dequantize_per_tensor_default_5 = layer2_0_conv1_weight_bias = None
    relu__5 = torch.ops.aten.relu_.default(conv2d_5);  conv2d_5 = None
    quantize_per_tensor_default_30 = torch.ops.quantized_decomposed.quantize_per_tensor.default(relu__5, 0.007734538055956364, -128, -128, 127, torch.int8);  relu__5 = None
    dequantize_per_tensor_default_30 = torch.ops.quantized_decomposed.dequantize_per_tensor.default(quantize_per_tensor_default_30, 0.007734538055956364, -128, -128, 127, torch.int8);  quantize_per_tensor_default_30 = None
    layer2_0_conv2_weight_bias = getattr(self.layer2, "0").conv2.weight_bias
    conv2d_6 = torch.ops.aten.conv2d.default(dequantize_per_tensor_default_30, dequantize_per_tensor_default_6, layer2_0_conv2_weight_bias, [1, 1], [1, 1]);  dequantize_per_tensor_default_30 = dequantize_per_tensor_default_6 = layer2_0_conv2_weight_bias = None
    quantize_per_tensor_default_31 = torch.ops.quantized_decomposed.quantize_per_tensor.default(conv2d_6, 0.0218511912971735, -7, -128, 127, torch.int8);  conv2d_6 = None
    dequantize_per_tensor_default_31 = torch.ops.quantized_decomposed.dequantize_per_tensor.default(quantize_per_tensor_default_31, 0.0218511912971735, -7, -128, 127, torch.int8);  quantize_per_tensor_default_31 = None
    layer2_0_downsample_0_weight_bias = getattr(getattr(self.layer2, "0").downsample, "0").weight_bias
    conv2d_7 = torch.ops.aten.conv2d.default(dequantize_per_tensor_default_59, dequantize_per_tensor_default_7, layer2_0_downsample_0_weight_bias, [2, 2]);  dequantize_per_tensor_default_59 = dequantize_per_tensor_default_7 = layer2_0_downsample_0_weight_bias = None
    quantize_per_tensor_default_32 = torch.ops.quantized_decomposed.quantize_per_tensor.default(conv2d_7, 0.01700599491596222, 6, -128, 127, torch.int8);  conv2d_7 = None
    dequantize_per_tensor_default_32 = torch.ops.quantized_decomposed.dequantize_per_tensor.default(quantize_per_tensor_default_32, 0.01700599491596222, 6, -128, 127, torch.int8);  quantize_per_tensor_default_32 = None
    add__2 = torch.ops.aten.add_.Tensor(dequantize_per_tensor_default_31, dequantize_per_tensor_default_32);  dequantize_per_tensor_default_31 = dequantize_per_tensor_default_32 = None
    relu__6 = torch.ops.aten.relu_.default(add__2);  add__2 = None
    quantize_per_tensor_default_33 = torch.ops.quantized_decomposed.quantize_per_tensor.default(relu__6, 0.013916007243096828, -128, -128, 127, torch.int8);  relu__6 = None
    dequantize_per_tensor_default_61 = torch.ops.quantized_decomposed.dequantize_per_tensor.default(quantize_per_tensor_default_33, 0.013916007243096828, -128, -128, 127, torch.int8)
    dequantize_per_tensor_default_60 = torch.ops.quantized_decomposed.dequantize_per_tensor.default(quantize_per_tensor_default_33, 0.013916007243096828, -128, -128, 127, torch.int8);  quantize_per_tensor_default_33 = None
    layer2_1_conv1_weight_bias = getattr(self.layer2, "1").conv1.weight_bias
    conv2d_8 = torch.ops.aten.conv2d.default(dequantize_per_tensor_default_60, dequantize_per_tensor_default_8, layer2_1_conv1_weight_bias, [1, 1], [1, 1]);  dequantize_per_tensor_default_60 = dequantize_per_tensor_default_8 = layer2_1_conv1_weight_bias = None
    relu__7 = torch.ops.aten.relu_.default(conv2d_8);  conv2d_8 = None
    quantize_per_tensor_default_34 = torch.ops.quantized_decomposed.quantize_per_tensor.default(relu__7, 0.008215637877583504, -128, -128, 127, torch.int8);  relu__7 = None
    dequantize_per_tensor_default_34 = torch.ops.quantized_decomposed.dequantize_per_tensor.default(quantize_per_tensor_default_34, 0.008215637877583504, -128, -128, 127, torch.int8);  quantize_per_tensor_default_34 = None
    layer2_1_conv2_weight_bias = getattr(self.layer2, "1").conv2.weight_bias
    conv2d_9 = torch.ops.aten.conv2d.default(dequantize_per_tensor_default_34, dequantize_per_tensor_default_9, layer2_1_conv2_weight_bias, [1, 1], [1, 1]);  dequantize_per_tensor_default_34 = dequantize_per_tensor_default_9 = layer2_1_conv2_weight_bias = None
    quantize_per_tensor_default_35 = torch.ops.quantized_decomposed.quantize_per_tensor.default(conv2d_9, 0.02301773801445961, 11, -128, 127, torch.int8);  conv2d_9 = None
    dequantize_per_tensor_default_35 = torch.ops.quantized_decomposed.dequantize_per_tensor.default(quantize_per_tensor_default_35, 0.02301773801445961, 11, -128, 127, torch.int8);  quantize_per_tensor_default_35 = None
    add__3 = torch.ops.aten.add_.Tensor(dequantize_per_tensor_default_35, dequantize_per_tensor_default_61);  dequantize_per_tensor_default_35 = dequantize_per_tensor_default_61 = None
    relu__8 = torch.ops.aten.relu_.default(add__3);  add__3 = None
    quantize_per_tensor_default_36 = torch.ops.quantized_decomposed.quantize_per_tensor.default(relu__8, 0.01587001420557499, -128, -128, 127, torch.int8);  relu__8 = None
    dequantize_per_tensor_default_63 = torch.ops.quantized_decomposed.dequantize_per_tensor.default(quantize_per_tensor_default_36, 0.01587001420557499, -128, -128, 127, torch.int8)
    dequantize_per_tensor_default_62 = torch.ops.quantized_decomposed.dequantize_per_tensor.default(quantize_per_tensor_default_36, 0.01587001420557499, -128, -128, 127, torch.int8);  quantize_per_tensor_default_36 = None
    layer3_0_conv1_weight_bias = getattr(self.layer3, "0").conv1.weight_bias
    conv2d_10 = torch.ops.aten.conv2d.default(dequantize_per_tensor_default_62, dequantize_per_tensor_default_10, layer3_0_conv1_weight_bias, [2, 2], [1, 1]);  dequantize_per_tensor_default_62 = dequantize_per_tensor_default_10 = layer3_0_conv1_weight_bias = None
    relu__9 = torch.ops.aten.relu_.default(conv2d_10);  conv2d_10 = None
    quantize_per_tensor_default_37 = torch.ops.quantized_decomposed.quantize_per_tensor.default(relu__9, 0.009096985682845116, -128, -128, 127, torch.int8);  relu__9 = None
    dequantize_per_tensor_default_37 = torch.ops.quantized_decomposed.dequantize_per_tensor.default(quantize_per_tensor_default_37, 0.009096985682845116, -128, -128, 127, torch.int8);  quantize_per_tensor_default_37 = None
    layer3_0_conv2_weight_bias = getattr(self.layer3, "0").conv2.weight_bias
    conv2d_11 = torch.ops.aten.conv2d.default(dequantize_per_tensor_default_37, dequantize_per_tensor_default_11, layer3_0_conv2_weight_bias, [1, 1], [1, 1]);  dequantize_per_tensor_default_37 = dequantize_per_tensor_default_11 = layer3_0_conv2_weight_bias = None
    quantize_per_tensor_default_38 = torch.ops.quantized_decomposed.quantize_per_tensor.default(conv2d_11, 0.02545665204524994, -31, -128, 127, torch.int8);  conv2d_11 = None
    dequantize_per_tensor_default_38 = torch.ops.quantized_decomposed.dequantize_per_tensor.default(quantize_per_tensor_default_38, 0.02545665204524994, -31, -128, 127, torch.int8);  quantize_per_tensor_default_38 = None
    layer3_0_downsample_0_weight_bias = getattr(getattr(self.layer3, "0").downsample, "0").weight_bias
    conv2d_12 = torch.ops.aten.conv2d.default(dequantize_per_tensor_default_63, dequantize_per_tensor_default_12, layer3_0_downsample_0_weight_bias, [2, 2]);  dequantize_per_tensor_default_63 = dequantize_per_tensor_default_12 = layer3_0_downsample_0_weight_bias = None
    quantize_per_tensor_default_39 = torch.ops.quantized_decomposed.quantize_per_tensor.default(conv2d_12, 0.008121605031192303, 35, -128, 127, torch.int8);  conv2d_12 = None
    dequantize_per_tensor_default_39 = torch.ops.quantized_decomposed.dequantize_per_tensor.default(quantize_per_tensor_default_39, 0.008121605031192303, 35, -128, 127, torch.int8);  quantize_per_tensor_default_39 = None
    add__4 = torch.ops.aten.add_.Tensor(dequantize_per_tensor_default_38, dequantize_per_tensor_default_39);  dequantize_per_tensor_default_38 = dequantize_per_tensor_default_39 = None
    relu__10 = torch.ops.aten.relu_.default(add__4);  add__4 = None
    quantize_per_tensor_default_40 = torch.ops.quantized_decomposed.quantize_per_tensor.default(relu__10, 0.013726901262998581, -128, -128, 127, torch.int8);  relu__10 = None
    dequantize_per_tensor_default_65 = torch.ops.quantized_decomposed.dequantize_per_tensor.default(quantize_per_tensor_default_40, 0.013726901262998581, -128, -128, 127, torch.int8)
    dequantize_per_tensor_default_64 = torch.ops.quantized_decomposed.dequantize_per_tensor.default(quantize_per_tensor_default_40, 0.013726901262998581, -128, -128, 127, torch.int8);  quantize_per_tensor_default_40 = None
    layer3_1_conv1_weight_bias = getattr(self.layer3, "1").conv1.weight_bias
    conv2d_13 = torch.ops.aten.conv2d.default(dequantize_per_tensor_default_64, dequantize_per_tensor_default_13, layer3_1_conv1_weight_bias, [1, 1], [1, 1]);  dequantize_per_tensor_default_64 = dequantize_per_tensor_default_13 = layer3_1_conv1_weight_bias = None
    relu__11 = torch.ops.aten.relu_.default(conv2d_13);  conv2d_13 = None
    quantize_per_tensor_default_41 = torch.ops.quantized_decomposed.quantize_per_tensor.default(relu__11, 0.008119435049593449, -128, -128, 127, torch.int8);  relu__11 = None
    dequantize_per_tensor_default_41 = torch.ops.quantized_decomposed.dequantize_per_tensor.default(quantize_per_tensor_default_41, 0.008119435049593449, -128, -128, 127, torch.int8);  quantize_per_tensor_default_41 = None
    layer3_1_conv2_weight_bias = getattr(self.layer3, "1").conv2.weight_bias
    conv2d_14 = torch.ops.aten.conv2d.default(dequantize_per_tensor_default_41, dequantize_per_tensor_default_14, layer3_1_conv2_weight_bias, [1, 1], [1, 1]);  dequantize_per_tensor_default_41 = dequantize_per_tensor_default_14 = layer3_1_conv2_weight_bias = None
    quantize_per_tensor_default_42 = torch.ops.quantized_decomposed.quantize_per_tensor.default(conv2d_14, 0.025257259607315063, 27, -128, 127, torch.int8);  conv2d_14 = None
    dequantize_per_tensor_default_42 = torch.ops.quantized_decomposed.dequantize_per_tensor.default(quantize_per_tensor_default_42, 0.025257259607315063, 27, -128, 127, torch.int8);  quantize_per_tensor_default_42 = None
    add__5 = torch.ops.aten.add_.Tensor(dequantize_per_tensor_default_42, dequantize_per_tensor_default_65);  dequantize_per_tensor_default_42 = dequantize_per_tensor_default_65 = None
    relu__12 = torch.ops.aten.relu_.default(add__5);  add__5 = None
    quantize_per_tensor_default_43 = torch.ops.quantized_decomposed.quantize_per_tensor.default(relu__12, 0.01491590216755867, -128, -128, 127, torch.int8);  relu__12 = None
    dequantize_per_tensor_default_67 = torch.ops.quantized_decomposed.dequantize_per_tensor.default(quantize_per_tensor_default_43, 0.01491590216755867, -128, -128, 127, torch.int8)
    dequantize_per_tensor_default_66 = torch.ops.quantized_decomposed.dequantize_per_tensor.default(quantize_per_tensor_default_43, 0.01491590216755867, -128, -128, 127, torch.int8);  quantize_per_tensor_default_43 = None
    layer4_0_conv1_weight_bias = getattr(self.layer4, "0").conv1.weight_bias
    conv2d_15 = torch.ops.aten.conv2d.default(dequantize_per_tensor_default_66, dequantize_per_tensor_default_15, layer4_0_conv1_weight_bias, [2, 2], [1, 1]);  dequantize_per_tensor_default_66 = dequantize_per_tensor_default_15 = layer4_0_conv1_weight_bias = None
    relu__13 = torch.ops.aten.relu_.default(conv2d_15);  conv2d_15 = None
    quantize_per_tensor_default_44 = torch.ops.quantized_decomposed.quantize_per_tensor.default(relu__13, 0.00674060545861721, -128, -128, 127, torch.int8);  relu__13 = None
    dequantize_per_tensor_default_44 = torch.ops.quantized_decomposed.dequantize_per_tensor.default(quantize_per_tensor_default_44, 0.00674060545861721, -128, -128, 127, torch.int8);  quantize_per_tensor_default_44 = None
    layer4_0_conv2_weight_bias = getattr(self.layer4, "0").conv2.weight_bias
    conv2d_16 = torch.ops.aten.conv2d.default(dequantize_per_tensor_default_44, dequantize_per_tensor_default_16, layer4_0_conv2_weight_bias, [1, 1], [1, 1]);  dequantize_per_tensor_default_44 = dequantize_per_tensor_default_16 = layer4_0_conv2_weight_bias = None
    quantize_per_tensor_default_45 = torch.ops.quantized_decomposed.quantize_per_tensor.default(conv2d_16, 0.02473648078739643, 8, -128, 127, torch.int8);  conv2d_16 = None
    dequantize_per_tensor_default_45 = torch.ops.quantized_decomposed.dequantize_per_tensor.default(quantize_per_tensor_default_45, 0.02473648078739643, 8, -128, 127, torch.int8);  quantize_per_tensor_default_45 = None
    layer4_0_downsample_0_weight_bias = getattr(getattr(self.layer4, "0").downsample, "0").weight_bias
    conv2d_17 = torch.ops.aten.conv2d.default(dequantize_per_tensor_default_67, dequantize_per_tensor_default_17, layer4_0_downsample_0_weight_bias, [2, 2]);  dequantize_per_tensor_default_67 = dequantize_per_tensor_default_17 = layer4_0_downsample_0_weight_bias = None
    quantize_per_tensor_default_46 = torch.ops.quantized_decomposed.quantize_per_tensor.default(conv2d_17, 0.019072359427809715, 1, -128, 127, torch.int8);  conv2d_17 = None
    dequantize_per_tensor_default_46 = torch.ops.quantized_decomposed.dequantize_per_tensor.default(quantize_per_tensor_default_46, 0.019072359427809715, 1, -128, 127, torch.int8);  quantize_per_tensor_default_46 = None
    add__6 = torch.ops.aten.add_.Tensor(dequantize_per_tensor_default_45, dequantize_per_tensor_default_46);  dequantize_per_tensor_default_45 = dequantize_per_tensor_default_46 = None
    relu__14 = torch.ops.aten.relu_.default(add__6);  add__6 = None
    quantize_per_tensor_default_47 = torch.ops.quantized_decomposed.quantize_per_tensor.default(relu__14, 0.016225755214691162, -128, -128, 127, torch.int8);  relu__14 = None
    dequantize_per_tensor_default_69 = torch.ops.quantized_decomposed.dequantize_per_tensor.default(quantize_per_tensor_default_47, 0.016225755214691162, -128, -128, 127, torch.int8)
    dequantize_per_tensor_default_68 = torch.ops.quantized_decomposed.dequantize_per_tensor.default(quantize_per_tensor_default_47, 0.016225755214691162, -128, -128, 127, torch.int8);  quantize_per_tensor_default_47 = None
    layer4_1_conv1_weight_bias = getattr(self.layer4, "1").conv1.weight_bias
    conv2d_18 = torch.ops.aten.conv2d.default(dequantize_per_tensor_default_68, dequantize_per_tensor_default_18, layer4_1_conv1_weight_bias, [1, 1], [1, 1]);  dequantize_per_tensor_default_68 = dequantize_per_tensor_default_18 = layer4_1_conv1_weight_bias = None
    relu__15 = torch.ops.aten.relu_.default(conv2d_18);  conv2d_18 = None
    quantize_per_tensor_default_48 = torch.ops.quantized_decomposed.quantize_per_tensor.default(relu__15, 0.007310453336685896, -128, -128, 127, torch.int8);  relu__15 = None
    dequantize_per_tensor_default_48 = torch.ops.quantized_decomposed.dequantize_per_tensor.default(quantize_per_tensor_default_48, 0.007310453336685896, -128, -128, 127, torch.int8);  quantize_per_tensor_default_48 = None
    layer4_1_conv2_weight_bias = getattr(self.layer4, "1").conv2.weight_bias
    conv2d_19 = torch.ops.aten.conv2d.default(dequantize_per_tensor_default_48, dequantize_per_tensor_default_19, layer4_1_conv2_weight_bias, [1, 1], [1, 1]);  dequantize_per_tensor_default_48 = dequantize_per_tensor_default_19 = layer4_1_conv2_weight_bias = None
    quantize_per_tensor_default_49 = torch.ops.quantized_decomposed.quantize_per_tensor.default(conv2d_19, 0.12780876457691193, -43, -128, 127, torch.int8);  conv2d_19 = None
    dequantize_per_tensor_default_49 = torch.ops.quantized_decomposed.dequantize_per_tensor.default(quantize_per_tensor_default_49, 0.12780876457691193, -43, -128, 127, torch.int8);  quantize_per_tensor_default_49 = None
    add__7 = torch.ops.aten.add_.Tensor(dequantize_per_tensor_default_49, dequantize_per_tensor_default_69);  dequantize_per_tensor_default_49 = dequantize_per_tensor_default_69 = None
    relu__16 = torch.ops.aten.relu_.default(add__7);  add__7 = None
    quantize_per_tensor_default_50 = torch.ops.quantized_decomposed.quantize_per_tensor.default(relu__16, 0.09021393954753876, -128, -128, 127, torch.int8);  relu__16 = None
    dequantize_per_tensor_default_50 = torch.ops.quantized_decomposed.dequantize_per_tensor.default(quantize_per_tensor_default_50, 0.09021393954753876, -128, -128, 127, torch.int8);  quantize_per_tensor_default_50 = None
    adaptive_avg_pool2d = torch.ops.aten.adaptive_avg_pool2d.default(dequantize_per_tensor_default_50, [1, 1]);  dequantize_per_tensor_default_50 = None
    quantize_per_tensor_default_51 = torch.ops.quantized_decomposed.quantize_per_tensor.default(adaptive_avg_pool2d, 0.09021393954753876, -128, -128, 127, torch.int8);  adaptive_avg_pool2d = None
    dequantize_per_tensor_default_51 = torch.ops.quantized_decomposed.dequantize_per_tensor.default(quantize_per_tensor_default_51, 0.09021393954753876, -128, -128, 127, torch.int8);  quantize_per_tensor_default_51 = None
    flatten = torch.ops.aten.flatten.using_ints(dequantize_per_tensor_default_51, 1);  dequantize_per_tensor_default_51 = None
    quantize_per_tensor_default_52 = torch.ops.quantized_decomposed.quantize_per_tensor.default(flatten, 0.09021393954753876, -128, -128, 127, torch.int8);  flatten = None
    dequantize_per_tensor_default_52 = torch.ops.quantized_decomposed.dequantize_per_tensor.default(quantize_per_tensor_default_52, 0.09021393954753876, -128, -128, 127, torch.int8);  quantize_per_tensor_default_52 = None
    linear = torch.ops.aten.linear.default(dequantize_per_tensor_default_52, dequantize_per_tensor_default_20, fc_bias);  dequantize_per_tensor_default_52 = dequantize_per_tensor_default_20 = fc_bias = None
    quantize_per_tensor_default_53 = torch.ops.quantized_decomposed.quantize_per_tensor.default(linear, 0.146121546626091, -59, -128, 127, torch.int8);  linear = None
    dequantize_per_tensor_default_53 = torch.ops.quantized_decomposed.dequantize_per_tensor.default(quantize_per_tensor_default_53, 0.146121546626091, -59, -128, 127, torch.int8);  quantize_per_tensor_default_53 = None
    return pytree.tree_unflatten((dequantize_per_tensor_default_53,), self._out_spec)
    
# To see more debug info, please use `graph_module.print_readable()`

量化表示#

Q/DQ 表示#

目前提供两种表示形式供选择,但从长期来看,最终提供的具体表示形式可能会根据 PyTorch 用户的反馈进行调整。

  • Q/DQ 表示(默认)

  • 与之前文档中的表示一致,所有量化算子都表示为 dequantize -> fp32_op -> quantize

def quantized_linear(
    x_int8, x_scale, x_zero_point,
    weight_int8, weight_scale, weight_zero_point,
    bias_fp32,
    output_scale, output_zero_point,
    x_quant_min=-128, x_quant_max=127,
    weight_quant_min=-127, weight_quant_max=127,
    output_quant_min=-128, output_quant_max=127,
):
    """Quantized linear layer in the Q/DQ representation.

    The int8 input and weight are dequantized to fp32, the linear op runs in
    fp32 (``bias + x @ weight.T``), and the result is quantized back to int8.

    Args:
        x_int8: 2-D int8 input of shape ``(batch, in_features)``.
        x_scale, x_zero_point: Per-tensor quantization parameters of the input.
        weight_int8: 2-D int8 weight of shape ``(out_features, in_features)``.
        weight_scale, weight_zero_point: Per-tensor quantization parameters of the weight.
        bias_fp32: fp32 bias, broadcastable to ``(batch, out_features)``.
        output_scale, output_zero_point: Quantization parameters of the output.
        x_quant_min, x_quant_max: Integer range of the input (full int8 by default).
        weight_quant_min, weight_quant_max: Integer range of the weight; the
            default [-127, 127] matches the symmetric weight range used in the
            converted graph.
        output_quant_min, output_quant_max: Integer range of the output.

    Returns:
        int8 tensor of shape ``(batch, out_features)``.

    Assumes the ``quantized_decomposed`` ops are registered (importing
    ``torch.ao.quantization.fx._decomposed`` does this).
    """
    x_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
        x_int8, x_scale, x_zero_point, x_quant_min, x_quant_max, torch.int8)
    weight_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
        weight_int8, weight_scale, weight_zero_point, weight_quant_min, weight_quant_max, torch.int8)
    # linear(x, W, b) == b + x @ W.T, expressed with core aten ops.
    weight_permuted = torch.ops.aten.permute_copy.default(weight_fp32, [1, 0])
    out_fp32 = torch.ops.aten.addmm.default(bias_fp32, x_fp32, weight_permuted)
    out_int8 = torch.ops.quantized_decomposed.quantize_per_tensor.default(
        out_fp32, output_scale, output_zero_point, output_quant_min, output_quant_max, torch.int8)
    return out_int8

参考量化模型表示#

为选定的算子(例如量化线性层)提供特殊表示。其他算子表示为 dq -> float32_op -> q,并且 q/dq 会被分解为更基础的算子。您可以通过 convert_pt2e(..., use_reference_representation=True) 获取这种表示。

# Reference Quantized Pattern for quantized linear
def quantized_linear(x_int8, x_scale, x_zero_point, weight_int8, weight_scale, weight_zero_point, bias_fp32, output_scale, output_zero_point):
    x_int16 = x_int8.to(torch.int16)
    weight_int16 = weight_int8.to(torch.int16)
    acc_int32 = torch.ops.out_dtype(torch.mm, torch.int32, (x_int16 - x_zero_point), (weight_int16 - weight_zero_point))
    bias_scale = x_scale * weight_scale
    bias_int32 = out_dtype(torch.ops.aten.div.Tensor, torch.int32, bias_fp32, bias_scale)
    acc_int32 = acc_int32 + bias_int32
    acc_int32 = torch.ops.out_dtype(torch.ops.aten.mul.Scalar, torch.int32, acc_int32, x_scale * weight_scale / output_scale) + output_zero_point
    out_int8 = torch.ops.aten.clamp(acc_int32, qmin, qmax).to(torch.int8)
    return out_int8

检查模型大小和准确度评估#

现在将模型大小和模型精度与基线模型进行比较。

from utils import print_size_of_model, evaluate
# Baseline model size and accuracy
print("Size of baseline model")
print_size_of_model(float_model)

top1, top5 = evaluate(float_model, criterion, data_loader_test)
print("Baseline Float Model Evaluation accuracy: %2.2f, %2.2f"%(top1.avg, top5.avg))

# Quantized model size and accuracy
print("Size of model after quantization")
# Export again to remove unused weights.
# The batch dimension (dim 0 of the first input) must stay dynamic:
# `example_inputs` has a different batch size than the batches produced by
# `data_loader_test`, and a statically exported module rejects any other
# batch size ("Expected input at *args[0].shape[0] to be equal to ...").
batch_dim = torch.export.Dim("batch")
eval_dynamic_shapes = tuple(
    {0: batch_dim} if i == 0 else None
    for i in range(len(example_inputs))
)
quantized_model = torch.export.export(
    quantized_model, example_inputs, dynamic_shapes=eval_dynamic_shapes
).module()
print_size_of_model(quantized_model)

top1, top5 = evaluate(quantized_model, criterion, data_loader_test)
print("[before serialization] Evaluation accuracy on test dataset: %2.2f, %2.2f"%(top1.avg, top5.avg))
print("[before serilaization] Evaluation accuracy on test dataset: %2.2f, %2.2f"%(top1.avg, top5.avg))
Size of baseline model
Size (MB): 46.828683

Baseline Float Model Evaluation accuracy: 69.23, 88.81
Size of model after quantization
Size (MB): 11.713877
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[8], line 15
     12 quantized_model = torch.export.export(quantized_model, example_inputs).module()
     13 print_size_of_model(quantized_model)
---> 15 top1, top5 = evaluate(quantized_model, criterion, data_loader_test)
     16 print("[before serilaization] Evaluation accuracy on test dataset: %2.2f, %2.2f"%(top1.avg, top5.avg))

File /media/pc/data/lxw/ai/torch-book/doc/ecosystem/ExecuTorch/pt2e/utils.py:82, in evaluate(model, criterion, data_loader)
     80 with torch.no_grad():
     81     for image, target in data_loader:
---> 82         output = model(image)
     83         loss = criterion(output, target)
     84         cnt += 1

File /media/pc/data/lxw/envs/anaconda3a/envs/ai/lib/python3.12/site-packages/torch/fx/graph_module.py:830, in GraphModule.recompile.<locals>.call_wrapped(self, *args, **kwargs)
    829 def call_wrapped(self, *args, **kwargs):
--> 830     return self._wrapped_call(self, *args, **kwargs)

File /media/pc/data/lxw/envs/anaconda3a/envs/ai/lib/python3.12/site-packages/torch/fx/graph_module.py:406, in _WrappedCall.__call__(self, obj, *args, **kwargs)
    404     raise e.with_traceback(None)  # noqa: B904
    405 else:
--> 406     raise e

File /media/pc/data/lxw/envs/anaconda3a/envs/ai/lib/python3.12/site-packages/torch/fx/graph_module.py:393, in _WrappedCall.__call__(self, obj, *args, **kwargs)
    391         return self.cls_call(obj, *args, **kwargs)
    392     else:
--> 393         return super(self.cls, obj).__call__(*args, **kwargs)  # type: ignore[misc]
    394 except Exception as e:
    395     assert e.__traceback__

File /media/pc/data/lxw/envs/anaconda3a/envs/ai/lib/python3.12/site-packages/torch/nn/modules/module.py:1751, in Module._wrapped_call_impl(self, *args, **kwargs)
   1749     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1750 else:
-> 1751     return self._call_impl(*args, **kwargs)

File /media/pc/data/lxw/envs/anaconda3a/envs/ai/lib/python3.12/site-packages/torch/nn/modules/module.py:1857, in Module._call_impl(self, *args, **kwargs)
   1854     return inner()
   1856 try:
-> 1857     return inner()
   1858 except Exception:
   1859     # run always called hooks if they have not already been run
   1860     # For now only forward hooks have the always_call option but perhaps
   1861     # this functionality should be added to full backward hooks as well.
   1862     for hook_id, hook in _global_forward_hooks.items():

File /media/pc/data/lxw/envs/anaconda3a/envs/ai/lib/python3.12/site-packages/torch/nn/modules/module.py:1784, in Module._call_impl.<locals>.inner()
   1779 for hook_id, hook in (
   1780     *_global_forward_pre_hooks.items(),
   1781     *self._forward_pre_hooks.items(),
   1782 ):
   1783     if hook_id in self._forward_pre_hooks_with_kwargs:
-> 1784         args_kwargs_result = hook(self, args, kwargs)  # type: ignore[misc]
   1785         if args_kwargs_result is not None:
   1786             if isinstance(args_kwargs_result, tuple) and len(args_kwargs_result) == 2:

File /media/pc/data/lxw/envs/anaconda3a/envs/ai/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py:838, in DisableContext.__call__.<locals>._fn(*args, **kwargs)
    836 _maybe_set_eval_frame(_callback_from_stance(self.callback))
    837 try:
--> 838     return fn(*args, **kwargs)
    839 finally:
    840     set_eval_frame(None)

File /media/pc/data/lxw/envs/anaconda3a/envs/ai/lib/python3.12/site-packages/torch/export/_unlift.py:55, in _check_input_constraints_pre_hook(self, args, kwargs)
     51     return
     53 flat_args_with_path = _check_inputs_match(args, kwargs, self._in_spec)
---> 55 _check_input_constraints_for_graph(
     56     [node for node in self.graph.nodes if node.op == "placeholder"],
     57     flat_args_with_path,
     58     self.range_constraints,
     59 )

File /media/pc/data/lxw/envs/anaconda3a/envs/ai/lib/python3.12/site-packages/torch/_export/utils.py:398, in _check_input_constraints_for_graph(input_placeholders, flat_args_with_path, range_constraints)
    396             continue
    397         elif arg_dim != node_dim:
--> 398             raise RuntimeError(
    399                 f"Expected input at {get_keystr(key_path)}.shape[{j}] to be equal to "
    400                 f"{node_dim}, but got {arg_dim}",
    401             )
    402 elif isinstance(node_val, (int, float, str)):
    403     if type(arg) != type(node_val) or arg != node_val:

RuntimeError: Expected input at *args[0].shape[0] to be equal to 2, but got 50

小技巧

  1. 现在无法进行性能评估,因为模型尚未下放到目标设备上,它只是 ATen 运算中量化计算的表示。

  2. 目前的权重仍然是 fp32 格式,未来可能会对量化算子进行常量传播,以获得整数权重。

如果你想要获得更好的准确率或性能,可以尝试以不同的方式配置 quantizer ,而每个 quantizer 都会有其自己的配置方式,因此请查阅你所使用的量化器的文档,以了解更多关于如何更好地控制模型量化方法的信息。

保存和加载量化模型#

展示如何保存和加载量化模型:

# 0. Store the reference output for a batch of example inputs, and check the
#    evaluation accuracy before serialization:
import os

example_inputs = (next(iter(data_loader))[0],)
ref = quantized_model(*example_inputs)
top1, top5 = evaluate(quantized_model, criterion, data_loader_test)
print("[before serialization] Evaluation accuracy on test dataset: %2.2f, %2.2f"%(top1.avg, top5.avg))

# 1. Export the model and Save ExportedProgram
# NOTE(review): `saved_model_dir` is not defined in this section; confirm it is
# set in an earlier cell. os.path.join works with or without a trailing slash.
os.makedirs(saved_model_dir, exist_ok=True)
pt2e_quantized_model_file_path = os.path.join(saved_model_dir, "resnet18_pt2e_quantized.pth")
# Capture the model to get an ExportedProgram. The batch dimension is kept
# dynamic because `example_inputs` comes from the train loader while the
# accuracy check below feeds batches from `data_loader_test`, whose batch size
# differs; a static export would reject those batches.
batch_dim = torch.export.Dim("batch")
save_dynamic_shapes = tuple(
    {0: batch_dim} if i == 0 else None
    for i in range(len(example_inputs))
)
quantized_ep = torch.export.export(
    quantized_model, example_inputs, dynamic_shapes=save_dynamic_shapes
)
# use torch.export.save to save an ExportedProgram
torch.export.save(quantized_ep, pt2e_quantized_model_file_path)


# 2. Load the saved ExportedProgram
loaded_quantized_ep = torch.export.load(pt2e_quantized_model_file_path)
loaded_quantized_model = loaded_quantized_ep.module()

# 3. Check results for example inputs and check evaluation accuracy again.
#    Report the max absolute difference instead of dumping the whole tensor.
res = loaded_quantized_model(*example_inputs)
print("diff:", (ref - res).abs().max().item())

top1, top5 = evaluate(loaded_quantized_model, criterion, data_loader_test)
print("[after serialization/deserialization] Evaluation accuracy on test dataset: %2.2f, %2.2f"%(top1.avg, top5.avg))