Wrapping Graph Output Dynamically
The following code demonstrates how to change an existing Graph based on parameters specified at runtime. We'll let the user specify an activation function from a predefined Enum list, then symbolically trace it. Next, we'll create a Proxy from the last operation in the Graph. We'll call the traced activation function with this Proxy and insert the `output` Node from that call into our Graph. (This last step will automatically inline the entire traced function.)
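Before the full example, here is a minimal warm-up sketch of the underlying mechanism (the module `Double` and variable `gm` are illustrative, not part of the example below): wrapping the Graph's final Node in an `fx.Proxy` means that ordinary `torch` calls on that Proxy are recorded as new Nodes in the same Graph.

import torch
from torch import fx

# Illustrative warm-up module (assumption: not from the original example)
class Double(torch.nn.Module):
    def forward(self, x):
        return x * 2

gm = fx.symbolic_trace(Double())

# Find the output node and the Node that feeds it
output_node = next(n for n in reversed(gm.graph.nodes) if n.op == "output")
last = output_node.all_input_nodes[0]

# Any torch op called on the Proxy appends a Node at the insertion point
with gm.graph.inserting_after(last):
    relu_proxy = torch.relu(fx.Proxy(last))
output_node.args = (relu_proxy.node,)

gm.recompile()
assert torch.equal(gm(torch.tensor([-1.0, 2.0])), torch.tensor([0.0, 4.0]))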
from enum import Enum, auto
from typing import Optional

import torch
from torch import fx
class M(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x, y):
        y = torch.cat([x, y])
        return y
# Symbolically trace an instance of `M`
traced = fx.symbolic_trace(M())
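At this point, `traced.graph` contains two placeholder nodes (for `x` and `y`), one `call_function` node targeting `torch.cat`, and an `output` node. You can inspect it with:

print(traced.graph)

(The exact textual format of the printout varies across PyTorch versions.)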
Select an activation function:
class ActivationFunction(Enum):
    RELU = auto()
    LEAKY_RELU = auto()
    PRELU = auto()
Map activation function names to their implementations:
activation_functions = {
    ActivationFunction.RELU: torch.nn.ReLU(),
    ActivationFunction.LEAKY_RELU: torch.nn.LeakyReLU(),
    ActivationFunction.PRELU: torch.nn.PReLU(),
}
def wrap_in_activation_function(m: fx.GraphModule, fn: ActivationFunction) -> fx.GraphModule:
    # Get the output node
    output_node: Optional[fx.Node] = None
    for n in reversed(m.graph.nodes):
        if n.op == "output":
            output_node = n
            break
    assert output_node

    # Get the actual output (the "input" of the output node). This is
    # the Node we want to wrap in a user-specified activation function
    assert len(output_node.all_input_nodes) == 1
    wrap_node = output_node.all_input_nodes[0]

    # Wrap the actual output in a Proxy
    wrap_proxy = fx.Proxy(wrap_node)

    # Get the implementation of the specified activation function and
    # symbolically trace it
    fn_impl = activation_functions[fn]
    fn_impl_traced = fx.symbolic_trace(fn_impl)

    # Call the specified activation function using the Proxy wrapper for
    # `wrap_node`. The result of this call is another Proxy, which we
    # can hook into our existing Graph.
    with m.graph.inserting_after(wrap_node):
        fn_impl_output_node = fn_impl_traced(wrap_proxy)
        new_args = (fn_impl_output_node.node,)
        output_node.args = new_args

    m.recompile()
    return m
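For comparison, here is a sketch of an alternative design that skips the Proxy-and-inline step, assuming the hypothetical helper name `wrap_with_call_module` (not part of the original example): register the activation as a submodule and splice a single `call_module` node into the graph. The activation then stays opaque as one node rather than being inlined op-by-op.

def wrap_with_call_module(m: fx.GraphModule, fn: ActivationFunction) -> fx.GraphModule:
    # Hypothetical alternative, not from the original example: keep the
    # activation as a submodule instead of inlining its traced ops
    m.add_submodule("wrapped_activation", activation_functions[fn])
    output_node = next(n for n in reversed(m.graph.nodes) if n.op == "output")
    wrap_node = output_node.all_input_nodes[0]
    with m.graph.inserting_after(wrap_node):
        new_node = m.graph.call_module("wrapped_activation", args=(wrap_node,))
    output_node.args = (new_node,)
    m.recompile()
    return m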
Now test `wrap_in_activation_function`:
x, y = torch.randn(5, 3), torch.randn(5, 3)
orig_output = traced(x, y)
wrap_in_activation_function(traced, ActivationFunction.LEAKY_RELU)
new_output = traced(x, y)
torch.testing.assert_close(new_output, torch.nn.LeakyReLU()(orig_output))
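To confirm the rewrite, you can also print the regenerated forward method; after the call above it should apply a leaky-ReLU to the result of `torch.cat` before returning (the exact generated code varies by PyTorch version):

print(traced.code)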