FX 算子替换#

  1. 遍历 GraphModuleGraph 中的所有 Node

  2. 确定是否应该替换当前 Node (建议:匹配节点的 target 属性)。

  3. 创建替换 Node 并将其添加到 Graph 中。

  4. 使用 FX 内置的 replace_all_uses_with() 替换当前 Node 的所有使用。

  5. Graph 中删除旧 Node

  6. GraphModule 上调用 recompile。这会更新生成的 Python 代码,以反射(reflect)新的 Graph 状态。

下面的代码演示了用按位 AND 替换任意加法实例的示例。

要检查 Graph 在运算替换期间的演变情况,可以在要检查的行之后添加语句 print(traced.graph)。 或者,调用 traced.graph.print_tabular() 以查看表格格式的 IR。

import torch
from torch import fx
import operator
# module 样例
class M(torch.nn.Module):
    def forward(self, x, y):
        return x + y, torch.add(x, y), x.add(y)

以符号方式跟踪模块的实例:

traced = fx.symbolic_trace(M())

有几种不同的表示加法的方法:

patterns = set([operator.add, torch.add, "add"])

# 遍历 Graph 中全部节点
for n in traced.graph.nodes:
    # 如果目标匹配其中一个模式
    if any(n.target == pattern for pattern in patterns):
        # 设置插入点,添加新节点,用新节点替换所有 `n` 的用法
        with traced.graph.inserting_after(n):
            new_node = traced.graph.call_function(torch.bitwise_and, n.args, n.kwargs)
            n.replace_all_uses_with(new_node)
        # 移除 graph 中旧的节点
        traced.graph.erase_node(n)

# 不用忘记 recompile!
new_code = traced.recompile()
print(new_code.src)
def forward(self, x, y):
    bitwise_and = torch.bitwise_and(x, y)
    bitwise_and_1 = torch.bitwise_and(x, y)
    bitwise_and_2 = torch.bitwise_and(x, y);  x = y = None
    return (bitwise_and, bitwise_and_1, bitwise_and_2)