操控 Graph
Contents
操控 Graph
#
构建新 Graph
的一种方法是直接操控旧图。为了帮助实现这一点,可以简单地从符号跟踪中获取 Graph
并对其进行修改。例如,假设希望用 torch.mul()
调用替换 torch.add()
调用。
import torch
from torch import fx
# 样例模块
class M(torch.nn.Module):
def forward(self, x, y):
return torch.add(x, y)
查看节点信息:
m = M()
gm: torch.fx.GraphModule = torch.fx.symbolic_trace(m)
for node in gm.graph.nodes:
print(node, node.op, node.target)
x placeholder x
y placeholder y
add call_function <built-in method add of type object at 0x7f5bb4da7200>
output output output
def transform(m: torch.nn.Module,
tracer_class: type = fx.Tracer) -> torch.nn.Module:
graph: fx.Graph = tracer_class().trace(m)
# FX 将其 Graph 表示为节点的有序列表,因此可以遍历它们。
for node in graph.nodes:
# 检查是否正在调用函数(例如:torch.add)
if node.op == 'call_function':
# target 属性是 call_function 调用的函数。
if node.target == torch.add:
node.target = torch.mul
graph.lint() # 做一些检查,以确保 Graph 是格式良好的。
return fx.GraphModule(m, graph)
或者使用更简洁的写法:
m = M()
traced: torch.fx.GraphModule = torch.fx.symbolic_trace(m)
for node in traced.graph.nodes:
if node.op == 'call_function':
# target 属性是 call_function 调用的函数。
if node.target == torch.add:
node.target = torch.mul
traced.graph.lint() # 做一些检查,以确保 Graph 是格式良好的。
traced.graph.print_tabular()
opcode name target args kwargs
------------- ------ ------------------------------------------------------ ------ --------
placeholder x x () {}
placeholder y y () {}
call_function add <built-in method mul of type object at 0x7f5bb4da7200> (x, y) {}
output output output (add,) {}
还可以进行更复杂的 Graph
重写,比如删除或追加节点。为了帮助完成这些变换,FX 提供了变换 Graph
的实用函数。下面是使用这些 API 附加 relu()
调用的示例。
def inserting_after(node, new_node=torch.relu):
"""指定插入点,并在此范围内添加到 Graph 中的任何节点都将插入到 `node` 之后"""
with traced.graph.inserting_after(node):
# 插入新的 `call_function` 节点调用 `torch.relu``
new_node = traced.graph.call_function(new_node, args=(node,))
# 希望所有使用 `node` 值的节点后添加 `relu` 回调
# 使用 `replace_all_uses_with` API 来做到这一点。
node.replace_all_uses_with(new_node)
对于仅由替换组成的简单变换,还可以使用 torch.fx.subgraph_rewriter
。
replace_pattern()
重写子图#
FX 在直接 Graph
操作的基础上还提供了另一个自动化级别。replace_pattern()
API 本质上是编辑 Graph
的“查找/替换”工具。它允许您指定 pattern
和 replacement
,它将跟踪这些函数,在 pattern
graph 中查找运算组的实例,并用 replacement
graph 的副本替换这些实例。这有助于极大地自动化繁琐的 graph 操作代码,随着变换变得更加复杂,这些代码可能会变得笨拙。
import torch
from torch.fx import symbolic_trace, subgraph_rewriter
class M(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, x, w1, w2):
m1 = torch.cat([w1, w2]).sum()
m2 = torch.cat([w1, w2]).sum()
return x + torch.max(m1) + torch.max(m2)
def pattern(w1, w2):
return torch.cat([w1, w2]).sum()
def replacement(w1, w2):
return torch.stack([w1, w2])
traced_module = symbolic_trace(M())
subgraph_rewriter.replace_pattern(traced_module, pattern, replacement)
[Match(anchor=max_1, nodes_map={output: max_1, sum_1: sum_1, cat: cat, w1: w1, w2: w2}),
Match(anchor=max_2, nodes_map={output: max_2, sum_1: sum_2, cat: cat_1, w1: w1, w2: w2})]
Proxy/Retracing#
另一种操作 Graph
的方法是重用符号跟踪中使用的 Proxy
机制。例如,假设想要编写一个变换,将 PyTorch 函数分解为更小的运算。它将把每个 F.relu(x)
调用变换为 (x > 0) * x
。一种可能是执行必要的 graph 重写,在 F.relu
之后插入比较和乘法,然后清理原来的 F.relu
。但是,可以通过使用 Proxy
对象自动地将运算记录到 Graph
中来自动化这个过程。
要使用此方法,将希望插入的运算编写为常规 PyTorch 代码,并使用 Proxy
对象作为参数调用该代码。这些代理对象将捕获对它们执行的操作,并将它们附加到 Graph
中。
from torch import fx
from torch.nn import functional as F
# 注意,这个分解(decomposition)规则可以理解为普通的Python
def relu_decomposition(x):
return (x > 0) * x
decomposition_rules = {}
decomposition_rules[F.relu] = relu_decomposition
def decompose(model: torch.nn.Module,
tracer_class : type = fx.Tracer) -> torch.nn.Module:
"""
将 `model` 分解为更小的复合运算。
目前,它只支持将 ReLU 分解为它的数学定义:(x > 0) * x
"""
graph : fx.Graph = tracer_class().trace(model)
new_graph = fx.Graph()
env = {}
tracer = fx.proxy.GraphAppendingTracer(graph)
for node in graph.nodes:
if node.op == 'call_function' and node.target in decomposition_rules:
# 通过使用代理包装参数,可以分派到适当的分解规则,
# 并通过符号跟踪隐式地将其添加到 Graph 中。
proxy_args = [fx.Proxy(env[x.name], tracer)
if isinstance(x, fx.Node) else x for x in node.args]
output_proxy = decomposition_rules[node.target](*proxy_args)
# 对 `Proxy` 的运算总是产生新的 `Proxy`,分解规则的返回值也不例外。
# 需要从 `Proxy` 中提取底层的 `Node`,以便在此变换的后续迭代中使用它。
new_node = output_proxy.node
env[node.name] = new_node
else:
# 默认情况:没有此节点的分解规则,所以只需要将它复制到新的 Graph 中。
new_node = new_graph.node_copy(node, lambda x: env[x.name])
env[node.name] = new_node
return fx.GraphModule(model, new_graph)
除了避免显式的 Graph
操作之外,使用 Proxy
还允许将重写规则指定为原生 Python 代码。对于需要大量重写规则的变换(如 vmap 或 grad),这通常可以提高规则的可读性和可维护性。注意,在调用 Proxy
时,还传递了指向底层变量 graph 的跟踪器。如果 graph 中的操作是 n-ary 的(例如 add 是二进制算子),那么调用 Proxy
不会创建 graph 跟踪器的多个实例,这会导致意外的运行时错误。推荐这种使用 Proxy
的方法,特别是当底层算子不能被安全地假定为一元的时候。
如何使用代理对象创建计算图#
可以直接在原始节点周围创建代理对象。这可用于创建独立于符号跟踪的 Graph
。
下面的代码演示了如何使用带有原始节点的代理将运算附加到新 Graph
。将创建两个参数( x
和 y
),对这些参数执行一些运算,然后将创建的所有内容添加到新的 Graph
中。然后将把这个 Graph
包装到 {class}
~torch.fx.GraphModule` 中。这样做会创建 nn.Module
的可运行实例。
创建独立于符号跟踪的计算图
graph = fx.Graph()
tracer = fx.proxy.GraphAppendingTracer(graph)
创建输入节点:
raw1 = graph.placeholder('x')
raw2 = graph.placeholder('y')
使用原始节点和图的默认跟踪器初始化代理
y = fx.Proxy(raw1, tracer)
z = fx.Proxy(raw2, tracer)
生成其他运算:
a = torch.cat([y, z])
b = torch.tanh(a)
c = torch.neg(b)
z = torch.add(b, c)
创建新的输出节点并将其添加到图中。通过这样做,图将包含刚刚创建的所有节点(因为它们都链接到输出节点).
graph.output(c.node)
output
将创建的图包装到 GraphModule
中,以获得最终的、可运行的 Module
的实例
mod = fx.GraphModule(torch.nn.Module(), graph)
print(mod.code)
def forward(self, x, y):
cat = torch.cat([x, y])
tanh = torch.tanh(cat); cat = None
neg = torch.neg(tanh); tanh = None
cat_1 = torch.cat([x, y]); x = y = None
tanh_1 = torch.tanh(cat_1); cat_1 = None
neg_1 = torch.neg(tanh_1)
add = torch.add(tanh_1, neg_1); tanh_1 = None
return neg_1