解释器模式#

FX 中一个有用的代码组织模式是循环遍历 Graph 中的所有 Node 并执行它们。这可以用于一些事情,包括对流经 Graph 的值的运行时分析,或者通过使用 Proxy 进行重跟踪的代码变换。例如,假设想要运行 GraphModule 并记录 Tensor shape 和节点上的 dtype 属性,就像我们在运行时看到的那样。它可能看起来像:

import torch
import torch.fx
from torch.fx.node import Node

from typing import Dict

class ShapeProp:
    """
    Shape 传播。这个类接受 `GraphModule`。
    然后,使用给定的参数逐个节点地执行 `GraphModule` 的 `propagate` 方法。
    当每个运算执行时,ShapeProp 类存储每个运算的输出值 `Node` 的属性 `shape` 和 `dtype`。
    """
    def __init__(self, mod):
        self.mod = mod
        self.graph = mod.graph
        self.modules = dict(self.mod.named_modules())

    def propagate(self, *args):
        args_iter = iter(args)
        env : Dict[str, Node] = {}

        def load_arg(a):
            return fx.graph.map_arg(a, lambda n: env[n.name])

        def fetch_attr(target : str):
            target_atoms = target.split('.')
            attr_itr = self.mod
            for i, atom in enumerate(target_atoms):
                if not hasattr(attr_itr, atom):
                    raise RuntimeError(f"Node referenced nonexistant target {'.'.join(target_atoms[:i])}")
                attr_itr = getattr(attr_itr, atom)
            return attr_itr

        for node in self.graph.nodes:
            if node.op == 'placeholder':
                result = next(args_iter)
            elif node.op == 'get_attr':
                result = fetch_attr(node.target)
            elif node.op == 'call_function':
                result = node.target(*load_arg(node.args), **load_arg(node.kwargs))
            elif node.op == 'call_method':
                self_obj, *args = load_arg(node.args)
                kwargs = load_arg(node.kwargs)
                result = getattr(self_obj, node.target)(*args, **kwargs)
            elif node.op == 'call_module':
                result = self.modules[node.target](*load_arg(node.args), **load_arg(node.kwargs))
                
            # 这是唯一专门用于 shape 传播的代码。
            # 你可以删除 `if` 分支,它就变成了通用的 GraphModule 解释器。
            if isinstance(result, torch.Tensor):
                node.shape = result.shape
                node.dtype = result.dtype

            env[node.name] = result

        return load_arg(self.graph.result)

正如您所看到的,完整的 FX 解释器(interpreter)并不复杂,但它可能非常有用。为了方便使用这种模式,提供了 Interpreter 类,它以一种可以通过方法重写来重写解释器执行的某些方面的方式包含了上述逻辑。

除了执行运算之外,还可以通过解释器提供 Proxy 值来生成新的 Graph。类似地,提供 Transformer 类来包含此模式。Transformer 的行为类似于 Interpreter,但不是调用 run 方法从模块中获取具体的输出值,而是调用 torch.fx.Transformer.transform() 方法来返回新的 GraphModule,它服从于作为覆盖方法安装的任何变换规则。