操控 Graph#

构建新 Graph 的一种方法是直接操控旧图。为了帮助实现这一点,可以简单地从符号跟踪中获取 Graph 并对其进行修改。例如,假设希望用 torch.mul() 调用替换 torch.add() 调用。

import torch
from torch import fx

# 样例模块
class M(torch.nn.Module):
    def forward(self, x, y):
        return torch.add(x, y)

查看节点信息:

m = M()
gm: torch.fx.GraphModule = torch.fx.symbolic_trace(m)
for node in gm.graph.nodes:
    print(node, node.op, node.target)
x placeholder x
y placeholder y
add call_function <built-in method add of type object at 0x7f5bb4da7200>
output output output
def transform(m: torch.nn.Module,
              tracer_class: type = fx.Tracer) -> torch.nn.Module:
    graph: fx.Graph = tracer_class().trace(m)
    # FX 将其 Graph 表示为节点的有序列表,因此可以遍历它们。
    for node in graph.nodes:
        # 检查是否正在调用函数(例如:torch.add)
        if node.op == 'call_function':
            # target 属性是 call_function 调用的函数。
            if node.target == torch.add:
                node.target = torch.mul

    graph.lint() # 做一些检查,以确保 Graph 是格式良好的。
    return fx.GraphModule(m, graph)

或者使用更简洁的写法:

m = M()
traced: torch.fx.GraphModule = torch.fx.symbolic_trace(m)
for node in traced.graph.nodes:
    if node.op == 'call_function':
        # target 属性是 call_function 调用的函数。
        if node.target == torch.add:
            node.target = torch.mul
traced.graph.lint() # 做一些检查,以确保 Graph 是格式良好的。
traced.graph.print_tabular()
opcode         name    target                                                  args    kwargs
-------------  ------  ------------------------------------------------------  ------  --------
placeholder    x       x                                                       ()      {}
placeholder    y       y                                                       ()      {}
call_function  add     <built-in method mul of type object at 0x7f5bb4da7200>  (x, y)  {}
output         output  output                                                  (add,)  {}

还可以进行更复杂的 Graph 重写,比如删除或追加节点。为了帮助完成这些变换,FX 提供了变换 Graph 的实用函数。下面是使用这些 API 附加 relu() 调用的示例。

def inserting_after(node, new_node=torch.relu):
    """指定插入点,并在此范围内添加到 Graph 中的任何节点都将插入到 `node` 之后"""
    with traced.graph.inserting_after(node):
        # 插入新的 `call_function` 节点调用 `torch.relu``
        new_node = traced.graph.call_function(new_node, args=(node,))
         
        # 希望所有使用 `node` 值的节点后添加 `relu` 回调
        # 使用 `replace_all_uses_with` API 来做到这一点。
        node.replace_all_uses_with(new_node)

对于仅由替换组成的简单变换,还可以使用 torch.fx.subgraph_rewriter

replace_pattern() 重写子图#

FX 在直接 Graph 操作的基础上还提供了另一个自动化级别。replace_pattern() API 本质上是编辑 Graph 的“查找/替换”工具。它允许您指定 patternreplacement,它将跟踪这些函数,在 pattern graph 中查找运算组的实例,并用 replacement graph 的副本替换这些实例。这有助于极大地自动化繁琐的 graph 操作代码,随着变换变得更加复杂,这些代码可能会变得笨拙。

import torch
from torch.fx import symbolic_trace, subgraph_rewriter

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x, w1, w2):
        m1 = torch.cat([w1, w2]).sum()
        m2 = torch.cat([w1, w2]).sum()
        return x + torch.max(m1) + torch.max(m2)

def pattern(w1, w2):
    return torch.cat([w1, w2]).sum()

def replacement(w1, w2):
    return torch.stack([w1, w2])

traced_module = symbolic_trace(M())

subgraph_rewriter.replace_pattern(traced_module, pattern, replacement)
[Match(anchor=max_1, nodes_map={output: max_1, sum_1: sum_1, cat: cat, w1: w1, w2: w2}),
 Match(anchor=max_2, nodes_map={output: max_2, sum_1: sum_2, cat: cat_1, w1: w1, w2: w2})]

Proxy/Retracing#

另一种操作 Graph 的方法是重用符号跟踪中使用的 Proxy 机制。例如,假设想要编写一个变换,将 PyTorch 函数分解为更小的运算。它将把每个 F.relu(x) 调用变换为 (x > 0) * x。一种可能是执行必要的 graph 重写,在 F.relu 之后插入比较和乘法,然后清理原来的 F.relu。但是,可以通过使用 Proxy 对象自动地将运算记录到 Graph 中来自动化这个过程。

要使用此方法,将希望插入的运算编写为常规 PyTorch 代码,并使用 Proxy 对象作为参数调用该代码。这些代理对象将捕获对它们执行的操作,并将它们附加到 Graph 中。

from torch import fx
from torch.nn import functional as F

# 注意,这个分解(decomposition)规则可以理解为普通的Python
def relu_decomposition(x):
    return (x > 0) * x

decomposition_rules = {}
decomposition_rules[F.relu] = relu_decomposition

def decompose(model: torch.nn.Module,
              tracer_class : type = fx.Tracer) -> torch.nn.Module:
    """
    将 `model` 分解为更小的复合运算。
    目前,它只支持将 ReLU 分解为它的数学定义:(x > 0) * x
    """
    graph : fx.Graph = tracer_class().trace(model)
    new_graph = fx.Graph()
    env = {}
    tracer = fx.proxy.GraphAppendingTracer(graph)
    for node in graph.nodes:
        if node.op == 'call_function' and node.target in decomposition_rules:
            # 通过使用代理包装参数,可以分派到适当的分解规则,
            # 并通过符号跟踪隐式地将其添加到 Graph 中。
            proxy_args = [fx.Proxy(env[x.name], tracer) 
                          if isinstance(x, fx.Node) else x for x in node.args]
            output_proxy = decomposition_rules[node.target](*proxy_args)
            
            # 对 `Proxy` 的运算总是产生新的 `Proxy`,分解规则的返回值也不例外。
            # 需要从 `Proxy` 中提取底层的 `Node`,以便在此变换的后续迭代中使用它。
            new_node = output_proxy.node
            env[node.name] = new_node
        else:
            # 默认情况:没有此节点的分解规则,所以只需要将它复制到新的 Graph 中。
            new_node = new_graph.node_copy(node, lambda x: env[x.name])
            env[node.name] = new_node
    return fx.GraphModule(model, new_graph)

除了避免显式的 Graph 操作之外,使用 Proxy 还允许将重写规则指定为原生 Python 代码。对于需要大量重写规则的变换(如 vmap 或 grad),这通常可以提高规则的可读性和可维护性。注意,在调用 Proxy 时,还传递了指向底层变量 graph 的跟踪器。如果 graph 中的操作是 n-ary 的(例如 add 是二进制算子),那么调用 Proxy 不会创建 graph 跟踪器的多个实例,这会导致意外的运行时错误。推荐这种使用 Proxy 的方法,特别是当底层算子不能被安全地假定为一元的时候。

如何使用代理对象创建计算图#

可以直接在原始节点周围创建代理对象。这可用于创建独立于符号跟踪的 Graph

下面的代码演示了如何使用带有原始节点的代理将运算附加到新 Graph。将创建两个参数( xy ),对这些参数执行一些运算,然后将创建的所有内容添加到新的 Graph中。然后将把这个 Graph 包装到 {class}~torch.fx.GraphModule` 中。这样做会创建 nn.Module 的可运行实例。

创建独立于符号跟踪的计算图

graph = fx.Graph()
tracer = fx.proxy.GraphAppendingTracer(graph)

创建输入节点:

raw1 = graph.placeholder('x')
raw2 = graph.placeholder('y')

使用原始节点和图的默认跟踪器初始化代理

y = fx.Proxy(raw1, tracer)
z = fx.Proxy(raw2, tracer)

生成其他运算:

a = torch.cat([y, z])
b = torch.tanh(a)
c = torch.neg(b)
z = torch.add(b, c)

创建新的输出节点并将其添加到图中。通过这样做,图将包含刚刚创建的所有节点(因为它们都链接到输出节点).

graph.output(c.node)
output

将创建的图包装到 GraphModule 中,以获得最终的、可运行的 Module 的实例

mod = fx.GraphModule(torch.nn.Module(), graph)
print(mod.code)
def forward(self, x, y):
    cat = torch.cat([x, y])
    tanh = torch.tanh(cat);  cat = None
    neg = torch.neg(tanh);  tanh = None
    cat_1 = torch.cat([x, y]);  x = y = None
    tanh_1 = torch.tanh(cat_1);  cat_1 = None
    neg_1 = torch.neg(tanh_1)
    add = torch.add(tanh_1, neg_1);  tanh_1 = None
    return neg_1