Building a Simple CPU Performance Profiler with FX#

Reference: fx_profiling_tutorial

In this tutorial, we will use FX to do the following:

  1. Capture PyTorch Python code in a way that lets us inspect and gather statistics about the structure and execution of the code.

  2. Build out a small class that serves as a simple performance "profiler", collecting runtime statistics about each part of the model from actual runs.

For demonstration purposes, we will use the torchvision ResNet18 model throughout this tutorial.

import statistics, tabulate, time
from typing import Any, Dict, List
import torch
from torch import fx
from torchvision import models

rn18 = models.resnet18()
rn18.eval()
ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer2): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer3): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer4): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
  (fc): Linear(in_features=512, out_features=1000, bias=True)
)

We want to get a deeper look at its performance. That is, for the following invocation, which parts of the model take the longest?

input = torch.randn(5, 3, 224, 224)
output = rn18(input)

A common way of answering that question is to go through the program source, add code that collects timestamps at various points in the program, and compare the differences between those timestamps to see how long the regions between them take.
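For instance, a hand-instrumented measurement of a single submodule might look like the following sketch (timing rn18.conv1 here is our own arbitrary choice):

# Manually bracket one region of the model with timestamps
t0 = time.time()
intermediate = rn18.conv1(input)
t1 = time.time()
print(f'conv1 took {t1 - t0:.6f} seconds')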

That technique is certainly applicable to PyTorch code, but it would be nicer if we didn't have to copy over model code and edit it, especially code we haven't written (like this torchvision model). Instead, we will use FX to automate this "instrumentation" process without needing to modify any source code.

Note

tabulate is an external library that is not a dependency of PyTorch. We use it to more easily visualize the performance data. Please make sure you've installed it from your favorite Python package source.
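For example, it can be installed with pip:

pip install tabulate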

Capturing the Model with Symbolic Tracing#

Next, we will use FX's symbolic tracing mechanism to capture the definition of our model in a data structure we can manipulate and examine.

traced_rn18 = torch.fx.symbolic_trace(rn18)
print(traced_rn18.graph)
graph():
    %x : torch.Tensor [#users=1] = placeholder[target=x]
    %conv1 : [#users=1] = call_module[target=conv1](args = (%x,), kwargs = {})
    %bn1 : [#users=1] = call_module[target=bn1](args = (%conv1,), kwargs = {})
    %relu : [#users=1] = call_module[target=relu](args = (%bn1,), kwargs = {})
    %maxpool : [#users=2] = call_module[target=maxpool](args = (%relu,), kwargs = {})
    %layer1_0_conv1 : [#users=1] = call_module[target=layer1.0.conv1](args = (%maxpool,), kwargs = {})
    %layer1_0_bn1 : [#users=1] = call_module[target=layer1.0.bn1](args = (%layer1_0_conv1,), kwargs = {})
    %layer1_0_relu : [#users=1] = call_module[target=layer1.0.relu](args = (%layer1_0_bn1,), kwargs = {})
    %layer1_0_conv2 : [#users=1] = call_module[target=layer1.0.conv2](args = (%layer1_0_relu,), kwargs = {})
    %layer1_0_bn2 : [#users=1] = call_module[target=layer1.0.bn2](args = (%layer1_0_conv2,), kwargs = {})
    %add : [#users=1] = call_function[target=operator.add](args = (%layer1_0_bn2, %maxpool), kwargs = {})
    %layer1_0_relu_1 : [#users=2] = call_module[target=layer1.0.relu](args = (%add,), kwargs = {})
    %layer1_1_conv1 : [#users=1] = call_module[target=layer1.1.conv1](args = (%layer1_0_relu_1,), kwargs = {})
    %layer1_1_bn1 : [#users=1] = call_module[target=layer1.1.bn1](args = (%layer1_1_conv1,), kwargs = {})
    %layer1_1_relu : [#users=1] = call_module[target=layer1.1.relu](args = (%layer1_1_bn1,), kwargs = {})
    %layer1_1_conv2 : [#users=1] = call_module[target=layer1.1.conv2](args = (%layer1_1_relu,), kwargs = {})
    %layer1_1_bn2 : [#users=1] = call_module[target=layer1.1.bn2](args = (%layer1_1_conv2,), kwargs = {})
    %add_1 : [#users=1] = call_function[target=operator.add](args = (%layer1_1_bn2, %layer1_0_relu_1), kwargs = {})
    %layer1_1_relu_1 : [#users=2] = call_module[target=layer1.1.relu](args = (%add_1,), kwargs = {})
    %layer2_0_conv1 : [#users=1] = call_module[target=layer2.0.conv1](args = (%layer1_1_relu_1,), kwargs = {})
    %layer2_0_bn1 : [#users=1] = call_module[target=layer2.0.bn1](args = (%layer2_0_conv1,), kwargs = {})
    %layer2_0_relu : [#users=1] = call_module[target=layer2.0.relu](args = (%layer2_0_bn1,), kwargs = {})
    %layer2_0_conv2 : [#users=1] = call_module[target=layer2.0.conv2](args = (%layer2_0_relu,), kwargs = {})
    %layer2_0_bn2 : [#users=1] = call_module[target=layer2.0.bn2](args = (%layer2_0_conv2,), kwargs = {})
    %layer2_0_downsample_0 : [#users=1] = call_module[target=layer2.0.downsample.0](args = (%layer1_1_relu_1,), kwargs = {})
    %layer2_0_downsample_1 : [#users=1] = call_module[target=layer2.0.downsample.1](args = (%layer2_0_downsample_0,), kwargs = {})
    %add_2 : [#users=1] = call_function[target=operator.add](args = (%layer2_0_bn2, %layer2_0_downsample_1), kwargs = {})
    %layer2_0_relu_1 : [#users=2] = call_module[target=layer2.0.relu](args = (%add_2,), kwargs = {})
    %layer2_1_conv1 : [#users=1] = call_module[target=layer2.1.conv1](args = (%layer2_0_relu_1,), kwargs = {})
    %layer2_1_bn1 : [#users=1] = call_module[target=layer2.1.bn1](args = (%layer2_1_conv1,), kwargs = {})
    %layer2_1_relu : [#users=1] = call_module[target=layer2.1.relu](args = (%layer2_1_bn1,), kwargs = {})
    %layer2_1_conv2 : [#users=1] = call_module[target=layer2.1.conv2](args = (%layer2_1_relu,), kwargs = {})
    %layer2_1_bn2 : [#users=1] = call_module[target=layer2.1.bn2](args = (%layer2_1_conv2,), kwargs = {})
    %add_3 : [#users=1] = call_function[target=operator.add](args = (%layer2_1_bn2, %layer2_0_relu_1), kwargs = {})
    %layer2_1_relu_1 : [#users=2] = call_module[target=layer2.1.relu](args = (%add_3,), kwargs = {})
    %layer3_0_conv1 : [#users=1] = call_module[target=layer3.0.conv1](args = (%layer2_1_relu_1,), kwargs = {})
    %layer3_0_bn1 : [#users=1] = call_module[target=layer3.0.bn1](args = (%layer3_0_conv1,), kwargs = {})
    %layer3_0_relu : [#users=1] = call_module[target=layer3.0.relu](args = (%layer3_0_bn1,), kwargs = {})
    %layer3_0_conv2 : [#users=1] = call_module[target=layer3.0.conv2](args = (%layer3_0_relu,), kwargs = {})
    %layer3_0_bn2 : [#users=1] = call_module[target=layer3.0.bn2](args = (%layer3_0_conv2,), kwargs = {})
    %layer3_0_downsample_0 : [#users=1] = call_module[target=layer3.0.downsample.0](args = (%layer2_1_relu_1,), kwargs = {})
    %layer3_0_downsample_1 : [#users=1] = call_module[target=layer3.0.downsample.1](args = (%layer3_0_downsample_0,), kwargs = {})
    %add_4 : [#users=1] = call_function[target=operator.add](args = (%layer3_0_bn2, %layer3_0_downsample_1), kwargs = {})
    %layer3_0_relu_1 : [#users=2] = call_module[target=layer3.0.relu](args = (%add_4,), kwargs = {})
    %layer3_1_conv1 : [#users=1] = call_module[target=layer3.1.conv1](args = (%layer3_0_relu_1,), kwargs = {})
    %layer3_1_bn1 : [#users=1] = call_module[target=layer3.1.bn1](args = (%layer3_1_conv1,), kwargs = {})
    %layer3_1_relu : [#users=1] = call_module[target=layer3.1.relu](args = (%layer3_1_bn1,), kwargs = {})
    %layer3_1_conv2 : [#users=1] = call_module[target=layer3.1.conv2](args = (%layer3_1_relu,), kwargs = {})
    %layer3_1_bn2 : [#users=1] = call_module[target=layer3.1.bn2](args = (%layer3_1_conv2,), kwargs = {})
    %add_5 : [#users=1] = call_function[target=operator.add](args = (%layer3_1_bn2, %layer3_0_relu_1), kwargs = {})
    %layer3_1_relu_1 : [#users=2] = call_module[target=layer3.1.relu](args = (%add_5,), kwargs = {})
    %layer4_0_conv1 : [#users=1] = call_module[target=layer4.0.conv1](args = (%layer3_1_relu_1,), kwargs = {})
    %layer4_0_bn1 : [#users=1] = call_module[target=layer4.0.bn1](args = (%layer4_0_conv1,), kwargs = {})
    %layer4_0_relu : [#users=1] = call_module[target=layer4.0.relu](args = (%layer4_0_bn1,), kwargs = {})
    %layer4_0_conv2 : [#users=1] = call_module[target=layer4.0.conv2](args = (%layer4_0_relu,), kwargs = {})
    %layer4_0_bn2 : [#users=1] = call_module[target=layer4.0.bn2](args = (%layer4_0_conv2,), kwargs = {})
    %layer4_0_downsample_0 : [#users=1] = call_module[target=layer4.0.downsample.0](args = (%layer3_1_relu_1,), kwargs = {})
    %layer4_0_downsample_1 : [#users=1] = call_module[target=layer4.0.downsample.1](args = (%layer4_0_downsample_0,), kwargs = {})
    %add_6 : [#users=1] = call_function[target=operator.add](args = (%layer4_0_bn2, %layer4_0_downsample_1), kwargs = {})
    %layer4_0_relu_1 : [#users=2] = call_module[target=layer4.0.relu](args = (%add_6,), kwargs = {})
    %layer4_1_conv1 : [#users=1] = call_module[target=layer4.1.conv1](args = (%layer4_0_relu_1,), kwargs = {})
    %layer4_1_bn1 : [#users=1] = call_module[target=layer4.1.bn1](args = (%layer4_1_conv1,), kwargs = {})
    %layer4_1_relu : [#users=1] = call_module[target=layer4.1.relu](args = (%layer4_1_bn1,), kwargs = {})
    %layer4_1_conv2 : [#users=1] = call_module[target=layer4.1.conv2](args = (%layer4_1_relu,), kwargs = {})
    %layer4_1_bn2 : [#users=1] = call_module[target=layer4.1.bn2](args = (%layer4_1_conv2,), kwargs = {})
    %add_7 : [#users=1] = call_function[target=operator.add](args = (%layer4_1_bn2, %layer4_0_relu_1), kwargs = {})
    %layer4_1_relu_1 : [#users=1] = call_module[target=layer4.1.relu](args = (%add_7,), kwargs = {})
    %avgpool : [#users=1] = call_module[target=avgpool](args = (%layer4_1_relu_1,), kwargs = {})
    %flatten : [#users=1] = call_function[target=torch.flatten](args = (%avgpool, 1), kwargs = {})
    %fc : [#users=1] = call_module[target=fc](args = (%flatten,), kwargs = {})
    return fc

This gives us a Graph representation of the ResNet18 model. A Graph consists of a series of Nodes connected to each other. Each Node represents a call-site in the Python code (whether to a function, a module, or a method), and the edges (represented as args and kwargs on each Node) represent the values passed between these call-sites. More information about the Graph representation and the rest of the FX APIs can be found in the FX documentation.
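To make this concrete, the Nodes and their edges can be inspected programmatically; a small sketch that prints the first few nodes along with their incoming edges:

for node in list(traced_rn18.graph.nodes)[:6]:
    # node.op is the opcode, node.target the call target, and
    # node.args the incoming edges (values produced by other nodes)
    print(node.op, node.target, node.args)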

Creating a Profiling Interpreter#

Next, we will create a class that inherits from Interpreter. While the GraphModule that symbolic_trace produces compiles Python code that runs when you call the GraphModule, an alternative way of running a GraphModule is to execute each Node in the Graph one by one. That is the functionality Interpreter provides: it interprets the graph node-by-node.
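As a quick sanity check of that execution model, running the traced graph through a plain Interpreter should match calling the module directly (a sketch of our own, reusing the input defined above):

base_interp = fx.Interpreter(traced_rn18)
# Node-by-node execution produces the same output as the compiled module
torch.testing.assert_close(base_interp.run(input), rn18(input))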

通过从 Interpreter 继承,可以覆盖各种功能并安装我们想要的分析行为。目标是拥有一个对象,我们可以向其传递模型,调用模型 1 次或更多次,然后获得关于模型和模型的每个部分在这些运行期间所花费的时间的统计数据。

class ProfilingInterpreter(fx.Interpreter):
    def __init__(self, mod: torch.nn.Module):
        # Rather than have the user symbolically trace their model,
        # we do it in the constructor. As a result, the user can pass
        # in any ``Module`` without having to worry about the symbolic
        # tracing API.
        gm = fx.symbolic_trace(mod)
        super().__init__(gm)

        # We are going to store away two things here:
        #
        # 1. A list of total runtimes for ``mod``. In other words, we
        #    store the time each call to ``mod(...)`` took through this
        #    interpreter.
        self.total_runtime_sec: List[float] = []
        # 2. A map from ``Node`` to a list of times (in seconds) that
        #    node took to run. This is similar to (1) but for specific
        #    sub-parts of the model.
        self.runtimes_sec: Dict[torch.fx.Node, List[float]] = {}

    ######################################################################
    # Next, let's override ``run()``. ``Interpreter``'s ``run`` method is
    # the top-level entry point for execution of the model. We want to
    # intercept it so that we can record the total runtime of the model.

    def run(self, *args) -> Any:
        # Record the time we started running the model
        t_start = time.time()
        # Run the model by delegating back to Interpreter.run()
        return_val = super().run(*args)
        # Record the time we finished running the model
        t_end = time.time()
        # Store the total elapsed time this model execution took
        self.total_runtime_sec.append(t_end - t_start)
        return return_val

    ######################################################################
    # Now, let's override ``run_node``. ``Interpreter`` calls ``run_node``
    # each time it executes a single node. We will intercept this to
    # measure and record the time each individual call in the model takes.
    def run_node(self, n: torch.fx.Node) -> Any:
        # Record the time we started running the op
        t_start = time.time()
        # Run the op by delegating back to Interpreter.run_node()
        return_val = super().run_node(n)
        # Record the time we finished running the op
        t_end = time.time()
        # If we don't have an entry for this node in our runtimes_sec
        # data structure, add one with an empty list value.
        self.runtimes_sec.setdefault(n, [])
        # Record the total elapsed time for this single invocation in the
        # runtimes_sec data structure
        self.runtimes_sec[n].append(t_end - t_start)
        return return_val

    ######################################################################
    # Finally, we are going to define a method (one that does not override
    # any ``Interpreter`` method) that provides a nice, organized view of
    # the data we have collected.

    def summary(self, should_sort: bool = False) -> str:
        # Build up a list of summary information for each node
        node_summaries: List[List[Any]] = []
        # Calculate the mean runtime for the whole network. Because the
        # network may have been called multiple times during profiling,
        # we need to summarize the runtimes. We choose to use the
        # arithmetic mean for this.
        mean_total_runtime = statistics.mean(self.total_runtime_sec)

        # For each node, record summary statistics
        for node, runtimes in self.runtimes_sec.items():
            # Similarly, compute the mean runtime for ``node``
            mean_runtime = statistics.mean(runtimes)
            # For easier understanding, we also compute the percentage of
            # time each node took with respect to the whole network.
            pct_total = mean_runtime / mean_total_runtime * 100
            # Record the node's type, name, mean runtime, and percent runtime
            node_summaries.append([node.op, str(node),
                                   mean_runtime, pct_total])

        # One of the most important questions to answer when doing
        # performance profiling is "Which op(s) took the longest?". We can
        # make this easy to see by providing sorting functionality in our
        # summary view
        if should_sort:
            node_summaries.sort(key=lambda s: s[2], reverse=True)

        # Use the ``tabulate`` library to create a well-formatted table
        # presenting our summary information
        headers: List[str] = ['Op type', 'Op',
                              'Average runtime (s)',
                              'Pct total runtime']
        return tabulate.tabulate(node_summaries, headers=headers)
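Because the interpreter records every run, calling run() several times yields more stable averages in summary(). The next section runs the model just once, but a repeated-run variant is a straightforward extension (a sketch of our own):

multi_interp = ProfilingInterpreter(rn18)
for _ in range(10):
    # Each call appends fresh timings to the interpreter's statistics
    multi_interp.run(input)
print(multi_interp.summary(should_sort=True))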

Note

We use Python's time.time function to pull wall-clock timestamps and compare them. This is not the most accurate way to measure performance and will only give us a first-order approximation. We use this simple technique for demonstration purposes in this tutorial only.
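If slightly better resolution is desired, time.perf_counter is a monotonic clock designed for interval measurement. A hypothetical drop-in refinement (PerfCounterInterpreter is our own name, and only run_node is shown, so run() would still use time.time):

class PerfCounterInterpreter(ProfilingInterpreter):
    def run_node(self, n: torch.fx.Node) -> Any:
        t_start = time.perf_counter()
        # Delegate straight to the base Interpreter so we don't also
        # trigger ProfilingInterpreter.run_node's time.time bookkeeping
        return_val = fx.Interpreter.run_node(self, n)
        t_end = time.perf_counter()
        self.runtimes_sec.setdefault(n, []).append(t_end - t_start)
        return return_val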

Investigating the Performance of ResNet18#

We can now use ProfilingInterpreter to inspect the performance characteristics of our ResNet18 model:

interp = ProfilingInterpreter(rn18)
interp.run(input)
print(interp.summary(True))
Op type        Op                       Average runtime (s)    Pct total runtime
-------------  ---------------------  ---------------------  -------------------
call_module    conv1                            0.00634789            10.7596
call_module    maxpool                          0.00393319             6.66672
call_module    layer4_0_conv2                   0.00305033             5.17027
call_module    layer4_1_conv2                   0.00285316             4.83607
call_module    layer4_1_conv1                   0.00277543             4.70433
call_module    layer1_0_conv2                   0.00272179             4.6134
call_module    layer1_1_conv2                   0.00269055             4.56046
call_module    bn1                              0.00231075             3.9167
call_module    layer3_0_conv2                   0.00221777             3.7591
call_module    layer3_1_conv2                   0.00215101             3.64594
call_module    layer3_1_conv1                   0.00209332             3.54815
call_module    layer1_0_conv1                   0.00208449             3.5332
call_module    layer2_1_conv2                   0.00204778             3.47096
call_module    layer1_1_conv1                   0.00199175             3.37599
call_module    layer2_1_conv1                   0.00197744             3.35175
call_module    layer2_0_conv2                   0.00197124             3.34124
call_module    layer4_0_conv1                   0.00188589             3.19657
call_module    layer2_0_conv1                   0.00169086             2.866
call_module    layer3_0_conv1                   0.00168705             2.85953
call_module    layer2_0_downsample_0            0.000978947            1.6593
call_module    layer3_0_downsample_0            0.000718355            1.2176
call_function  add                              0.000567913            0.962607
call_module    layer1_1_bn2                     0.000564814            0.957354
call_module    layer4_0_downsample_0            0.000509262            0.863194
call_function  add_1                            0.000474691            0.804597
call_module    relu                             0.000412941            0.699931
call_module    layer2_1_bn2                     0.000388384            0.658307
call_function  add_3                            0.000330925            0.560915
call_module    layer1_0_bn2                     0.000254631            0.431597
call_module    layer1_0_bn1                     0.000241041            0.408562
call_module    layer1_1_bn1                     0.000214815            0.36411
call_module    fc                               0.000196218            0.332588
call_module    layer4_1_bn2                     0.000154018            0.26106
call_module    layer4_0_bn2                     0.00015378             0.260656
call_module    avgpool                          0.000152349            0.258231
call_module    layer3_0_bn2                     0.00014019             0.237621
call_module    layer4_1_bn1                     0.00013876             0.235196
call_function  add_2                            0.000137806            0.23358
call_module    layer2_0_bn1                     0.000136614            0.231559
call_module    layer2_0_bn2                     0.000135422            0.229539
call_module    layer2_0_downsample_1            0.000133991            0.227114
call_module    layer4_0_downsample_1            0.000130177            0.220648
call_module    layer3_1_bn1                     0.000128984            0.218627
call_module    layer3_0_downsample_1            0.000124693            0.211353
call_function  add_5                            0.000124454            0.210949
call_module    layer3_0_bn1                     0.000123978            0.210141
call_module    layer2_1_bn1                     0.000123739            0.209737
call_module    layer3_1_bn2                     0.00012207             0.206908
call_module    layer1_0_relu                    0.000121832            0.206504
call_module    layer4_0_bn1                     0.000119448            0.202463
call_module    layer1_1_relu                    0.000116825            0.198017
call_module    layer2_0_relu                    0.000105381            0.17862
call_module    layer2_1_relu                    0.000101328            0.17175
call_function  add_4                            9.9659e-05             0.168921
call_module    layer3_1_relu                    9.63211e-05            0.163263
call_function  add_6                            9.63211e-05            0.163263
call_module    layer3_0_relu                    9.58443e-05            0.162455
call_module    layer4_0_relu                    9.53674e-05            0.161647
call_function  add_7                            9.27448e-05            0.157202
call_module    layer1_0_relu_1                  9.15527e-05            0.155181
call_module    layer2_1_relu_1                  9.05991e-05            0.153565
call_module    layer4_1_relu                    8.98838e-05            0.152352
call_module    layer1_1_relu_1                  8.82149e-05            0.149523
call_module    layer2_0_relu_1                  8.10623e-05            0.1374
call_module    layer4_0_relu_1                  7.51019e-05            0.127297
call_module    layer3_0_relu_1                  7.4625e-05             0.126489
call_module    layer4_1_relu_1                  7.43866e-05            0.126085
call_module    layer3_1_relu_1                  7.20024e-05            0.122043
call_function  flatten                          4.41074e-05            0.0747617
placeholder    x                                2.47955e-05            0.0420282
output         output                           1.74046e-05            0.0295006

Tip

  1. torch.nn.MaxPool2d takes up the most time (a known issue).

  2. torch.nn.BatchNorm2d also takes up significant time. BN fusion can be applied to improve performance (a sketch follows below).
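One way to experiment with the second point is torch.fx.experimental.optimization.fuse, which folds BatchNorm parameters into the preceding convolutions for inference. A sketch (this utility is experimental and its location may change between PyTorch releases):

from torch.fx.experimental.optimization import fuse

# Folding BatchNorm into Conv2d removes the bn* ops from the profile
# entirely; fusion requires eval mode, which we set earlier with rn18.eval()
fused_rn18 = fuse(rn18)
fused_interp = ProfilingInterpreter(fused_rn18)
fused_interp.run(input)
print(fused_interp.summary(True))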

record_function()#

The record_function context manager from torch.autograd.profiler marks a named range in profiler output. During normal eager execution, the foo range is recorded correctly:

import torch
from torch import fx
from torch.autograd import profiler

# Setup: a module with `record_function`
class Foo(torch.nn.Module):
  def forward(self, x):
    with profiler.record_function('foo'):
      return torch.relu(x)

f = Foo()
x = torch.randn(5, 3, 2)
with profiler.profile() as prof:
  f(x)

print(prof)
-------------------  ------------  ------------  ------------  ------------  ------------  ------------  
               Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls  
-------------------  ------------  ------------  ------------  ------------  ------------  ------------  
        aten::zeros         3.43%      17.000us         4.85%      24.000us      24.000us             1  
        aten::empty         1.41%       7.000us         1.41%       7.000us       7.000us             1  
        aten::zero_         0.00%       0.000us         0.00%       0.000us       0.000us             1  
                foo        87.27%     432.000us        95.15%     471.000us     471.000us             1  
        aten::empty         0.61%       3.000us         0.61%       3.000us       3.000us             1  
         aten::relu         4.04%      20.000us         7.27%      36.000us      36.000us             1  
    aten::clamp_min         3.23%      16.000us         3.23%      16.000us      16.000us             1  
-------------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 495.000us

FX tracing does not record the foo range, because symbolic tracing runs the record_function context manager eagerly at trace time rather than capturing it as nodes in the graph:

traced = fx.symbolic_trace(f)
with profiler.profile() as prof:
  traced(x)

print(prof)
-------------------  ------------  ------------  ------------  ------------  ------------  ------------  
               Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls  
-------------------  ------------  ------------  ------------  ------------  ------------  ------------  
         aten::relu        36.84%       7.000us       100.00%      19.000us      19.000us             1  
    aten::clamp_min        63.16%      12.000us        63.16%      12.000us      12.000us             1  
-------------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 19.000us

A custom tracer can restore the range: during tracing it patches record_function's __enter__ and __exit__ so that the scope is recorded as explicit call_function nodes in the graph:

class ProfilerTracer(fx.Tracer):
  def trace(self, root, concrete_args=None):
    # Save the original context-manager hooks so we can restore them later
    orig_record_function_enter = profiler.record_function.__enter__
    orig_record_function_exit = profiler.record_function.__exit__

    def fake_profiler_enter(_self):
      nonlocal self
      # Emit a node that enters the profiler scope at runtime, and keep
      # the returned handle proxy around for the matching exit node
      handle_proxy = self.create_proxy(
          kind='call_function',
          target=torch.ops.profiler._record_function_enter,
          args=(_self.name,),
          kwargs={})

      assert getattr(_self, '_fx_profiler_ctx', None) is None
      setattr(_self, '_fx_profiler_ctx', handle_proxy)
      return handle_proxy

    def fake_profiler_exit(_self, exc_type, exc_value, traceback):
      assert hasattr(_self, '_fx_profiler_ctx')
      handle_proxy = _self._fx_profiler_ctx
      # Emit the node that exits the profiler scope at runtime
      torch.ops.profiler._record_function_exit(handle_proxy)
      setattr(_self, '_fx_profiler_ctx', None)

    # Patch record_function so entering/exiting the context during tracing
    # records graph nodes instead of running the profiler eagerly
    profiler.record_function.__enter__ = fake_profiler_enter
    profiler.record_function.__exit__ = fake_profiler_exit

    try:
      return super().trace(root, concrete_args)
    finally:
      # Always restore the original hooks, even if tracing fails
      profiler.record_function.__enter__ = orig_record_function_enter
      profiler.record_function.__exit__ = orig_record_function_exit

pt = ProfilerTracer()

graph_with_profiler = pt.trace(f)
traced_with_profiler = fx.GraphModule(pt.root, graph_with_profiler)

with profiler.profile() as prof:
  traced_with_profiler(x)

print(prof)
-------------------  ------------  ------------  ------------  ------------  ------------  ------------  
               Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls  
-------------------  ------------  ------------  ------------  ------------  ------------  ------------  
                foo        95.05%     307.000us       100.00%     323.000us     323.000us             1  
        aten::empty         1.24%       4.000us         1.24%       4.000us       4.000us             1  
         aten::relu         1.24%       4.000us         3.72%      12.000us      12.000us             1  
    aten::clamp_min         2.48%       8.000us         2.48%       8.000us       8.000us             1  
-------------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 323.000us