Wrap Graph Output Dynamically

The following code demonstrates how to modify an existing Graph based on parameters specified at runtime. We'll let the user pick an activation function from a predefined Enum, then symbolically trace it. Next, we'll create a Proxy from the last operation in the Graph. We'll call the traced activation function with that Proxy and insert the "output" Node from the call into our Graph. (This final step automatically inlines the entire traced function.)

from enum import Enum, auto
import torch
from torch import fx

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x, y):
        y = torch.cat([x, y])
        return y

# Symbolically trace an instance of `M`
traced = fx.symbolic_trace(M())
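
As a quick sanity check (optional; not part of the transformation itself), you can print the traced Graph. For this `M`, it should contain two placeholder nodes for `x` and `y`, a single `call_function` node for `torch.cat`, and an `output` node, though the exact textual rendering varies across PyTorch versions:

print(traced.graph)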

Let the user choose an activation function:

class ActivationFunction(Enum):
    RELU = auto()
    LEAKY_RELU = auto()
    PRELU = auto()

Map the activation function names to their implementations:

activation_functions = {
    ActivationFunction.RELU: torch.nn.ReLU(),
    ActivationFunction.LEAKY_RELU: torch.nn.LeakyReLU(),
    ActivationFunction.PRELU: torch.nn.PReLU(),
}
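
Each value in this dictionary is an `nn.Module` instance, so it can be symbolically traced on its own, which is exactly what `wrap_in_activation_function` below does internally. As an illustration (assuming the definitions above), tracing `LeakyReLU` yields a small Graph whose operations can later be replayed into another Graph via a Proxy:

leaky_relu_traced = fx.symbolic_trace(activation_functions[ActivationFunction.LEAKY_RELU])
print(leaky_relu_traced.code)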

def wrap_in_activation_function(m: fx.GraphModule, fn: ActivationFunction) -> fx.GraphModule:
    # Get output node
    output_node: fx.Node | None = None
    for n in reversed(m.graph.nodes):
        if n.op == "output":
            output_node = n
            break
    assert output_node

    # Get the actual output (the "input" of the output node). This is
    # the Node we want to wrap in a user-specified activation function
    assert len(output_node.all_input_nodes) == 1
    wrap_node = output_node.all_input_nodes[0]

    # Wrap the actual output in a Proxy
    wrap_proxy = fx.Proxy(wrap_node)

    # Get the implementation of the specified activation function and
    # symbolically trace it
    fn_impl = activation_functions[fn]
    fn_impl_traced = fx.symbolic_trace(fn_impl)

    # Call the specified activation function using the Proxy wrapper for
    # `output_op`. The result of this call is another Proxy, which we
    # can hook into our existing Graph.
    with m.graph.inserting_after(wrap_node):
        fn_impl_output_node = fn_impl_traced(wrap_proxy)
        new_args = (fn_impl_output_node.node,)
        output_node.args = new_args

    # Regenerate the module's `forward` from the modified Graph
    m.recompile()
    return m

Test:

x, y = torch.randn(5, 3), torch.randn(5, 3)
orig_output = traced(x, y)

wrap_in_activation_function(traced, ActivationFunction.LEAKY_RELU)
new_output = traced(x, y)

torch.testing.assert_close(new_output, torch.nn.LeakyReLU()(orig_output))
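
To see the effect of the rewrite, print the regenerated Python code for the module. After the call above, the generated `forward` should contain an inlined `leaky_relu` call between the `cat` operation and the return, though variable names and exact formatting in the generated code may differ across PyTorch versions:

print(traced.code)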