解释器模式
解释器模式#
FX 中一个有用的代码组织模式是循环遍历 Graph
中的所有 Node
并执行它们。这可以用于一些事情,包括对流经 Graph
的值的运行时分析,或者通过使用 Proxy
进行重跟踪的代码变换。例如,假设想要运行 GraphModule
并记录 Tensor
shape 和节点上的 dtype 属性,就像我们在运行时看到的那样。它可能看起来像:
import torch
import torch.fx
from torch.fx.node import Node
from typing import Dict
class ShapeProp:
"""
Shape 传播。这个类接受 `GraphModule`。
然后,使用给定的参数逐个节点地执行 `GraphModule` 的 `propagate` 方法。
当每个运算执行时,ShapeProp 类存储每个运算的输出值 `Node` 的属性 `shape` 和 `dtype`。
"""
def __init__(self, mod):
self.mod = mod
self.graph = mod.graph
self.modules = dict(self.mod.named_modules())
def propagate(self, *args):
args_iter = iter(args)
env : Dict[str, Node] = {}
def load_arg(a):
return fx.graph.map_arg(a, lambda n: env[n.name])
def fetch_attr(target : str):
target_atoms = target.split('.')
attr_itr = self.mod
for i, atom in enumerate(target_atoms):
if not hasattr(attr_itr, atom):
raise RuntimeError(f"Node referenced nonexistant target {'.'.join(target_atoms[:i])}")
attr_itr = getattr(attr_itr, atom)
return attr_itr
for node in self.graph.nodes:
if node.op == 'placeholder':
result = next(args_iter)
elif node.op == 'get_attr':
result = fetch_attr(node.target)
elif node.op == 'call_function':
result = node.target(*load_arg(node.args), **load_arg(node.kwargs))
elif node.op == 'call_method':
self_obj, *args = load_arg(node.args)
kwargs = load_arg(node.kwargs)
result = getattr(self_obj, node.target)(*args, **kwargs)
elif node.op == 'call_module':
result = self.modules[node.target](*load_arg(node.args), **load_arg(node.kwargs))
# 这是唯一专门用于 shape 传播的代码。
# 你可以删除 `if` 分支,它就变成了通用的 GraphModule 解释器。
if isinstance(result, torch.Tensor):
node.shape = result.shape
node.dtype = result.dtype
env[node.name] = result
return load_arg(self.graph.result)
正如您所看到的,完整的 FX 解释器(interpreter)并不复杂,但它可能非常有用。为了方便使用这种模式,提供了 Interpreter
类,它以一种可以通过方法重写来重写解释器执行的某些方面的方式包含了上述逻辑。
除了执行运算之外,还可以通过解释器提供 Proxy
值来生成新的 Graph
。类似地,提供 Transformer
类来包含此模式。Transformer
的行为类似于 Interpreter
,但不是调用 run
方法从模块中获取具体的输出值,而是调用 torch.fx.Transformer.transform()
方法来返回新的 GraphModule
,它服从于作为覆盖方法安装的任何变换规则。