Custom Tracers#

ModulePathTracer#

We will define a custom Tracer instance that, for every operation it records, also notes the qualified name of the module the operation originated from.

from typing import Any, Callable, Dict, Optional, Tuple
import torch
from torch import fx
class ModulePathTracer(fx.Tracer):
    """
    ModulePathTracer 是 FX 跟踪器,对于每个运算,它还记录了运算起源于的模块的限定名。
    """
    
    # 正在跟踪的模块的当前限定名。
    # 顶级模块由空字符串表示。
    # 在进入 ``call_module`` 时更新,在退出 ``call_module`` 时恢复
    current_module_qualified_name : str = ''
    # 从 FX 节点到它起源模块的 qualname 的映射
    # 这在记录运算时由 `create_proxy` 记录
    node_to_originating_module : Dict[torch.fx.Node, str] = {}

    def call_module(self, m: torch.nn.Module, forward: Callable[..., Any],
                    args : Tuple[Any, ...], kwargs : Dict[str, Any]) -> Any:
        """
        Override of Tracer.call_module that:

        1. Stores away the qualified name of the caller for restoration later
        2. Installs the qualified name of the caller in `current_module_qualified_name`
           for retrieval by `create_proxy`
        3. Delegates into the normal Tracer.call_module method
        4. Restores the caller's qualified name into `current_module_qualified_name`
        """
        old_qualname = self.current_module_qualified_name
        try:
            self.current_module_qualified_name = self.path_of_module(m)
            return super().call_module(m, forward, args, kwargs)
        finally:
            self.current_module_qualified_name = old_qualname

    def create_proxy(self, kind: str, target: torch.fx.node.Target, args: Tuple[Any, ...],
                     kwargs: Dict[str, Any], name: Optional[str] = None, type_expr: Optional[Any] = None):
        """
        Override of `Tracer.create_proxy`. This override intercepts the recording
        of every operation and stores away the current traced module's qualified
        name in `node_to_originating_module`
        """
        proxy = super().create_proxy(kind, target, args, kwargs, name, type_expr)
        self.node_to_originating_module[proxy.node] = self.current_module_qualified_name
        return proxy
# Testing: let's see how this works on a torchvision ResNet18 model
import torchvision.models as models

# Model under test
rn18 = models.resnet18()

# Instantiate our ModulePathTracer and use that to trace our ResNet18
tracer = ModulePathTracer()
traced_rn18 = tracer.trace(rn18)

# Print (node, module qualified name) for every node in the Graph
for node in traced_rn18.nodes:
    module_qualname = tracer.node_to_originating_module.get(node)
    print('Node', node, 'is from module', module_qualname)
Node x is from module 
Node conv1 is from module conv1
Node bn1 is from module bn1
Node relu is from module relu
Node maxpool is from module maxpool
Node layer1_0_conv1 is from module layer1.0.conv1
Node layer1_0_bn1 is from module layer1.0.bn1
Node layer1_0_relu is from module layer1.0.relu
Node layer1_0_conv2 is from module layer1.0.conv2
Node layer1_0_bn2 is from module layer1.0.bn2
Node add is from module layer1.0
Node layer1_0_relu_1 is from module layer1.0.relu
Node layer1_1_conv1 is from module layer1.1.conv1
Node layer1_1_bn1 is from module layer1.1.bn1
Node layer1_1_relu is from module layer1.1.relu
Node layer1_1_conv2 is from module layer1.1.conv2
Node layer1_1_bn2 is from module layer1.1.bn2
Node add_1 is from module layer1.1
Node layer1_1_relu_1 is from module layer1.1.relu
Node layer2_0_conv1 is from module layer2.0.conv1
Node layer2_0_bn1 is from module layer2.0.bn1
Node layer2_0_relu is from module layer2.0.relu
Node layer2_0_conv2 is from module layer2.0.conv2
Node layer2_0_bn2 is from module layer2.0.bn2
Node layer2_0_downsample_0 is from module layer2.0.downsample.0
Node layer2_0_downsample_1 is from module layer2.0.downsample.1
Node add_2 is from module layer2.0
Node layer2_0_relu_1 is from module layer2.0.relu
Node layer2_1_conv1 is from module layer2.1.conv1
Node layer2_1_bn1 is from module layer2.1.bn1
Node layer2_1_relu is from module layer2.1.relu
Node layer2_1_conv2 is from module layer2.1.conv2
Node layer2_1_bn2 is from module layer2.1.bn2
Node add_3 is from module layer2.1
Node layer2_1_relu_1 is from module layer2.1.relu
Node layer3_0_conv1 is from module layer3.0.conv1
Node layer3_0_bn1 is from module layer3.0.bn1
Node layer3_0_relu is from module layer3.0.relu
Node layer3_0_conv2 is from module layer3.0.conv2
Node layer3_0_bn2 is from module layer3.0.bn2
Node layer3_0_downsample_0 is from module layer3.0.downsample.0
Node layer3_0_downsample_1 is from module layer3.0.downsample.1
Node add_4 is from module layer3.0
Node layer3_0_relu_1 is from module layer3.0.relu
Node layer3_1_conv1 is from module layer3.1.conv1
Node layer3_1_bn1 is from module layer3.1.bn1
Node layer3_1_relu is from module layer3.1.relu
Node layer3_1_conv2 is from module layer3.1.conv2
Node layer3_1_bn2 is from module layer3.1.bn2
Node add_5 is from module layer3.1
Node layer3_1_relu_1 is from module layer3.1.relu
Node layer4_0_conv1 is from module layer4.0.conv1
Node layer4_0_bn1 is from module layer4.0.bn1
Node layer4_0_relu is from module layer4.0.relu
Node layer4_0_conv2 is from module layer4.0.conv2
Node layer4_0_bn2 is from module layer4.0.bn2
Node layer4_0_downsample_0 is from module layer4.0.downsample.0
Node layer4_0_downsample_1 is from module layer4.0.downsample.1
Node add_6 is from module layer4.0
Node layer4_0_relu_1 is from module layer4.0.relu
Node layer4_1_conv1 is from module layer4.1.conv1
Node layer4_1_bn1 is from module layer4.1.bn1
Node layer4_1_relu is from module layer4.1.relu
Node layer4_1_conv2 is from module layer4.1.conv2
Node layer4_1_bn2 is from module layer4.1.bn2
Node add_7 is from module layer4.1
Node layer4_1_relu_1 is from module layer4.1.relu
Node avgpool is from module avgpool
Node flatten is from module 
Node fc is from module fc
Node output is from module None
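
With this mapping in hand you can, for example, collect all of the nodes that came from a particular submodule, or stash the qualified name on each node so that later passes can read it. A minimal sketch, reusing the tracer and traced_rn18 from above (the names layer1_0_nodes and originating_module are purely illustrative):

# Collect every node recorded while tracing `layer1.0`
layer1_0_nodes = [
    node for node in traced_rn18.nodes
    if tracer.node_to_originating_module.get(node, '').startswith('layer1.0')
]
print([node.name for node in layer1_0_nodes])

# Optionally stash the qualified name in `node.meta` so later passes can
# read it without holding a reference to the tracer
for node in traced_rn18.nodes:
    node.meta['originating_module'] = tracer.node_to_originating_module.get(node)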

Tracing through all ReLU submodules#

During symbolic tracing, some submodules are traced through and their constituent operations are recorded, while other submodules appear in the IR as atomic "call_module" nodes. Modules in the latter category are called "leaf modules". By default, all modules in the PyTorch standard library (torch.nn) are leaf modules. This can be changed by creating a custom Tracer and overriding is_leaf_module.

import torch
from torch import fx

class M1(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        return self.relu(x)

default_traced: fx.GraphModule = fx.symbolic_trace(M1())
default_traced
M1(
  (relu): ReLU()
)
default_traced.graph.print_tabular()
opcode       name    target    args     kwargs
-----------  ------  --------  -------  --------
placeholder  x       x         ()       {}
call_module  relu    relu      (x,)     {}
output       output  output    (relu,)  {}

Changing the default behavior for torch.nn.ReLU:

class LowerReluTracer(fx.Tracer):
    def is_leaf_module(self, m: torch.nn.Module, qualname: str):
        if isinstance(m, torch.nn.ReLU):
            return False
        return super().is_leaf_module(m, qualname)
lower_relu_tracer = LowerReluTracer()
custom_traced_graph: fx.Graph = lower_relu_tracer.trace(M1())
custom_traced_graph.print_tabular()
opcode         name    target                             args     kwargs
-------------  ------  ---------------------------------  -------  ------------------
placeholder    x       x                                  ()       {}
call_function  relu    <function relu at 0x7ff559aac160>  (x,)     {'inplace': False}
output         output  output                             (relu,)  {}
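
Note that Tracer.trace returns a Graph rather than a GraphModule. If you want to actually run the lowered graph, you can wrap it in a GraphModule yourself. A minimal sketch (lowered_module is an illustrative name; since the lowered graph contains no call_module or get_attr nodes, any root module will do):

# Wrap the traced Graph in a GraphModule so it can be called like a regular module
lowered_module = fx.GraphModule(torch.nn.Module(), custom_traced_graph)
x = torch.randn(4)
print(torch.equal(lowered_module(x), torch.relu(x)))  # expected: True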

Adding an extra attribute to each node#

Here, we will override create_node so that a new attribute is added to each Node as it is created.

class M2(torch.nn.Module):
    def forward(self, a, b):
        return a + b

class TaggingTracer(fx.Tracer):
    def create_node(self, kind: str, target: str | Callable,
                    args: Tuple[Any, ...], kwargs: Dict[str, Any], name: str | None = None,
                    type_expr: Any | None = None) -> fx.Node:
        n = super().create_node(kind, target, args, kwargs, name)
        n.tag = "foo"
        return n

custom_traced_graph: fx.Graph = TaggingTracer().trace(M2())

def assert_all_nodes_have_tags(g: fx.Graph) -> bool:
    for n in g.nodes:
        if not hasattr(n, "tag") or not n.tag == "foo":
            return False
    return True

# Prints "True"
print(assert_all_nodes_have_tags(custom_traced_graph))
True
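
Attaching an ad-hoc attribute such as tag to a Node works, but each Node also carries a meta dictionary that is the conventional place for per-node bookkeeping (shape propagation, for example, stores its results there). A minimal variant of the tracer above that writes into node.meta instead (MetaTaggingTracer is an illustrative name):

class MetaTaggingTracer(fx.Tracer):
    def create_node(self, kind, target, args, kwargs, name=None, type_expr=None) -> fx.Node:
        n = super().create_node(kind, target, args, kwargs, name)
        # Record the tag in the node's `meta` dict instead of a new attribute
        n.meta["tag"] = "foo"
        return n

meta_traced_graph: fx.Graph = MetaTaggingTracer().trace(M2())
print(all(n.meta.get("tag") == "foo" for n in meta_traced_graph.nodes))  # expected: True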

Inlining a function into an existing Graph#

A reason you might want to inline a function is to get around FX's default tracing behavior. For example, unless you've defined a custom Tracer, the out-of-the-box implementation of symbolic_trace causes references to torch.nn module instances to appear as explicit call_module calls rather than being traced through. Let's say this behavior is almost what you need; the only problem is that there is a single module call that you want to replace with an inlined trace of the function. Creating a custom Tracer would be too much work. Instead, you can accomplish this with Proxies.

The following code demonstrates how to trace a module and inline it into an existing Graph using Proxy. We'll trace the Graph, then iterate through its Nodes until we find the right place to swap out the call_module Node for an inlined trace. At that point, we'll create Proxies from the Node's args and kwargs. Then we'll call the function we want to replace with those Proxies, which in essence "traces" that function. Finally, we'll insert the result of that call into our Graph. (This last step automatically inlines the function.)

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        return self.relu(x) + 1.0

Symbolically trace an instance of M. After tracing, self.relu is represented as a call_module node.

m = fx.symbolic_trace(M())

Insert nodes from the torch.nn.ReLU graph in place of the original call to self.relu.

Create a graph-appending tracer that points to the original Graph.

tracer = fx.proxy.GraphAppendingTracer(m.graph)
for node in m.graph.nodes:
    # Find `call_module` Node in `m` that corresponds to `self.relu`.
    # This is the Node we want to swap out for an inlined version of the
    # same call
    if (node.op, node.target) == ("call_module", "relu"):
        with m.graph.inserting_before(node):
            # Create a Proxy from each Node in the current Node's
            # args/kwargs
            proxy_args = fx.map_arg(node.args, lambda n: fx.Proxy(n, tracer))
            proxy_kwargs = fx.map_arg(node.kwargs, lambda n: fx.Proxy(n, tracer))
            # Call `m.relu` with the newly-created Proxy arguments.
            # `m.relu` is the ReLU submodule itself; by calling it with
            # Proxies created from Nodes in `m`, we're emitting Nodes that
            # reference existing values in the IR.
            # The result of this call is another Proxy, which we can
            # hook into our existing Graph to complete the function
            # inlining.
            proxy_output = m.relu(*proxy_args, **proxy_kwargs)
            # Replace the relu `call_module` node with the inlined
            # version of the function
            node.replace_all_uses_with(proxy_output.node)
            # Make sure that the old relu Node is erased
            m.graph.erase_node(node)
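
After editing the graph, the GraphModule's generated forward is stale; calling recompile() regenerates it from the modified graph. A quick sanity check (a minimal sketch) that the inlined version still behaves like the eager module:

# Regenerate the Python code for `forward` from the edited graph
m.recompile()
print(m.code)

# The inlined GraphModule should match the eager module
x = torch.randn(3, 4)
print(torch.equal(m(x), M()(x)))  # expected: True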

Computing an inverse function with FX#

import torch
from torch import fx

An inverse mapping takes a function f(x) and returns a function g such that f(g(x)) == x. For example, since log(exp(x)) == x, exp and log are inverses of each other.

invert_mapping = {}
def add_inverse(a, b):
    invert_mapping[a] = b
    invert_mapping[b] = a
inverses = [
    (torch.sin, torch.arcsin),
    (torch.cos, torch.arccos),
    (torch.tan, torch.arctan),
    (torch.exp, torch.log),
]
for a, b in inverses:
    add_inverse(a, b)

The general strategy is to walk the graph backwards, transforming each node into its inverse node.

To do this, we swap the output and input of each function and then look up its inverse in invert_mapping. Note that this transform assumes every operation takes exactly one input and returns exactly one output.

def invert(model: torch.nn.Module) -> torch.nn.Module:
    fx_model = fx.symbolic_trace(model)
    new_graph = fx.Graph()  # Build a new graph to hold the inverted function
    env = {}
    for node in reversed(fx_model.graph.nodes):
        if node.op == 'call_function':
            # Create a node in the new graph with the inverse function, passing
            # `env[node.name]` (i.e. the previously created output node) as input.
            new_node = new_graph.call_function(invert_mapping[node.target], 
                                               (env[node.name],))
            env[node.args[0].name] = new_node
        elif node.op == 'output':
            # Turn the output into an input placeholder
            new_node = new_graph.placeholder(node.name)
            env[node.args[0].name] = new_node
        elif node.op == 'placeholder':
            # Turn the input placeholder into an output
            new_graph.output(env[node.name])
        else:
            raise RuntimeError("Not implemented")

    new_graph.lint()
    return fx.GraphModule(fx_model, new_graph)


def f(x):
    return torch.exp(torch.tan(x))

res = invert(f)
print(res.code)

print(f(res((torch.arange(5) + 1))))  # [1., 2., 3., 4., 5.]
def forward(self, output):
    log = torch.log(output);  output = None
    arctan = torch.arctan(log);  log = None
    return arctan
    
tensor([1., 2., 3., 4., 5.])
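
As a final sanity check (a minimal sketch using the f and res defined above), the inverse should round-trip any input in the domain of the composed functions; since exp's inverse is log, the values must be positive:

# Round-trip check: f(g(x)) should reproduce x up to floating-point error
x = torch.rand(8) + 0.5
print(torch.allclose(f(res(x)), x))  # expected: True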