FX 简介#

FX 是供开发人员用来转换 Module 实例的工具包。FX 由三个主要组件组成:符号跟踪器(symbolic tracer)、中间表示(intermediate representation,简写 IR)和 Python 代码生成(Python code generation)。

import torch
from torch.fx import symbolic_trace

# 用于演示的简单模块
class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.param = torch.nn.Parameter(torch.rand(3, 4))
        self.linear = torch.nn.Linear(4, 5)

    def forward(self, x):
        return self.linear(x + self.param).clamp(min=0.0, max=1.0)

module = MyModule()

符号跟踪前端(Symbolic tracing frontend)

备注

符号跟踪器 执行 Python 代码的“符号执行”。它通过代码提供虚假的值,称为 代理。记录对这些代理的运算。有关符号跟踪的更多信息可以在 symbolic_trace()Tracer 文档中找到。

捕获模块的语义:

symbolic_traced : torch.fx.GraphModule = symbolic_trace(module)

高级中间表示(intermediate representationIR)

备注

中间表示 是符号跟踪期间记录的运算的容器。它由一组 node 组成,这些 node 表示函数输入、调用站点(callsites,即函数、方法或 Module 实例)和返回值。关于 IR 的更多信息可以在 Graph 的文档中找到。IR 是应用变换(transformations)的格式。

计算图(graph)表示:

print(symbolic_traced.graph)
graph():
    %x : [#users=1] = placeholder[target=x]
    %param : [#users=1] = get_attr[target=param]
    %add : [#users=1] = call_function[target=operator.add](args = (%x, %param), kwargs = {})
    %linear : [#users=1] = call_module[target=linear](args = (%add,), kwargs = {})
    %clamp : [#users=1] = call_method[target=clamp](args = (%linear,), kwargs = {min: 0.0, max: 1.0})
    return clamp

代码生成(Code generation)

备注

Python 代码生成 使 FX 成为 Python 到 Python (或 Module-to-Module)的变换工具包。对于每个 Graph IR,可以创建与 Graph 语义匹配的有效 Python 代码。该功能封装在 GraphModule 中,它是 Module 实例,包含 Graph 以及从 Graph 生成的 forward() 方法。

有效的 Python 代码:

print(symbolic_traced.code)
def forward(self, x):
    param = self.param
    add = x + param;  x = param = None
    linear = self.linear(add);  add = None
    clamp = linear.clamp(min = 0.0, max = 1.0);  linear = None
    return clamp
    

总的来说,这个组件管道(symbolic tracing -> intermediate representation -> transforms -> Python code generation)构成了 FX 的 Python-to-Python 变换管道(pipeline)。此外,这些组件可以单独使用。例如,可以单独使用符号跟踪来捕获代码的形式,以便进行分析(而不是变换)。代码生成可以用于以编程方式生成模型,例如从配置文件生成模型。

编写变换#

什么是 FX 变换?本质上,它是这样的函数:

import torch
from torch import nn
import torch.fx

def transform(m: nn.Module,
              tracer_class : type = torch.fx.Tracer) -> torch.nn.Module:
    # 步骤 1:获取表示 `m` 代码的计算图表示

    # NOTE: torch.fx.symbolic_trace 是对 fx.Tracer.trace 调用和构造 GraphModule 的包装器。
    # 将在变换中分离它,以允许调用者自定义 tracing 行为。
    graph : torch.fx.Graph = tracer_class().trace(m)

    # 步骤 2: 修改此 Graph 或创建新的 Graph
    graph = ...

    # 步骤 3:返回构造的 Module
    return torch.fx.GraphModule(m, graph)

transformation 函数需要 Module 作为输入, 然后从该 Module 获得 Graph (即 IR)对其进行修改, 然后返回新的 Module。你应该把返回的 Module 想成和正常的 Module 一样:你可以把它传递给另一个 FX 变换,你可以把它传递给 TorchScript,或者你可以运行它。确保 FX 变换的输入和输出是 Module 将允许可组合性。

备注

也可以修改现有的 GraphModule,而不是创建新的 GraphModule,如下所示:

import torch
import torch.fx

def transform(m : nn.Module) -> nn.Module:
    gm : torch.fx.GraphModule = torch.fx.symbolic_trace(m)

    # Modify gm.graph
    # <...>

    # Recompile the forward() method of `gm` from its Graph
    gm.recompile()

    return gm

小技巧

注意,你必须调用 recompile() 来将 GraphModule 上生成的 forward() 方法与修改后的 Graph 同步。

假设您已经传入了被跟踪到 Graph 中的 Module,那么现在您可以采用两种主要方法来构建新的 Graph

Graph 入门#

Graph 的语义可以在 Graph 文档中找到完整的处理方法,但是在这里只介绍基础知识。Graph 是一个数据结构,表示 GraphModule 上的方法。这需要的信息是:

  • 此方法的输入是什么?

  • 此方法当中执行了哪些运算?

  • 此方法的输出是什么?

这三个概念都用 Node 实例表示。

用简短的例子来看看这是什么意思:

import torch
from torch import fx

class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.param = torch.nn.Parameter(torch.rand(3, 4))
        self.linear = torch.nn.Linear(4, 5)

    def forward(self, x):
        return torch.topk(torch.sum(self.linear(x + self.linear.weight).relu(), 
                                    dim=-1), 3)

m = MyModule()
gm = fx.symbolic_trace(m)

gm.graph.print_tabular()
opcode         name           target                                                   args                kwargs
-------------  -------------  -------------------------------------------------------  ------------------  -----------
placeholder    x              x                                                        ()                  {}
get_attr       linear_weight  linear.weight                                            ()                  {}
call_function  add            <built-in function add>                                  (x, linear_weight)  {}
call_module    linear         linear                                                   (add,)              {}
call_method    relu           relu                                                     (linear,)           {}
call_function  sum_1          <built-in method sum of type object at 0x7fafc504a200>   (relu,)             {'dim': -1}
call_function  topk           <built-in method topk of type object at 0x7fafc504a200>  (sum_1, 3)          {}
output         output         output                                                   (topk,)             {}

这里定义了模块 MyModule,用于演示,实例化它,象征性地跟踪它,然后调用 print_tabular() 方法打印出一个表,显示这个图的节点。

可以使用这些信息来回答上面提出的问题。

上述表格足以回答我们的三个问题:

  1. 这个方法的输入是什么?在 FX 中, 方法输入被表示为 placeholder 节点。在我们的例子中,只有一个 placeholder,可以推断出来我们的 forward 的函数除了首参数 self 外只有一个额外的输入(即 x)。

  2. 这个方法当中执行了哪些运算?我们可以看到 get_attrcall_funcationcall_module 等节点表示了方法中的运算。

  3. 这个方法的输出是什么?我们使用特别的 output 来表示 Graph 的输出。

现在知道了方法是如何在 torch.fx 中被记录表示的, 下一步便是通过 Graph 修改它。