Primitive library#

In this example, we define a library of "composite" operations. Composite operations are operations exposed as callable functions whose implementations are made up of several other operations.

Composite operations let you choose the level of abstraction at which your code is interpreted or transformed. We show that you can provide a function that inlines calls to these composites after tracing, and that a custom Tracer can be used to inline them automatically during tracing.

Composite operations are useful for exposing higher-level context to a backend or transformation while still retaining the ability to inspect things at a finer-grained level.

import torch
from torch import fx

def sigmoid_lowp(x: torch.Tensor):
    x = x.float()
    x = x.sigmoid()
    return x.half()

fx.wrap() indicates that the passed-in function should always be recorded as a call_function node rather than being traced through. Later, we will see how to:

a. inline the implementation of such a function, and b. define a tracer that automatically traces through such a function.

# primitive_library.py
fx.wrap(sigmoid_lowp)
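
fx.wrap() can equivalently be applied as a decorator at module level, which keeps the registration next to the definition. The snippet below is just an alternative way of writing the same registration, not an additional step:

# primitive_library.py (decorator form, equivalent)
@torch.fx.wrap
def sigmoid_lowp(x: torch.Tensor):
    x = x.float()
    x = x.sigmoid()
    return x.half()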

Similarly:

# primitive_library.py
def add_lowp(a: torch.Tensor, b: torch.Tensor):
    a, b = a.float(), b.float()
    c = a + b
    return c.half()

torch.fx.wrap(add_lowp)
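
Note that fx.wrap() only changes how these functions are recorded during symbolic tracing; called eagerly they behave like ordinary Python functions. As a quick sanity check (the half-precision inputs here are arbitrary):

import torch
from primitive_library import sigmoid_lowp, add_lowp

a = torch.randn(4).half()
b = torch.randn(4).half()
print(sigmoid_lowp(a).dtype)  # torch.float16
print(add_lowp(a, b).dtype)   # torch.float16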

Let's see what happens when we symbolically trace through code that uses these functions:

from primitive_library import sigmoid_lowp, add_lowp
class Foo(torch.nn.Module):
    def forward(self, x, y):
        x = sigmoid_lowp(x)
        y = sigmoid_lowp(y)
        return add_lowp(x, y)


traced = fx.symbolic_trace(Foo())
print(traced.code)
def forward(self, x, y):
    sigmoid_lowp = primitive_library_sigmoid_lowp(x);  x = None
    sigmoid_lowp_1 = primitive_library_sigmoid_lowp(y);  y = None
    add_lowp = primitive_library_add_lowp(sigmoid_lowp, sigmoid_lowp_1);  sigmoid_lowp = sigmoid_lowp_1 = None
    return add_lowp
    

Note that the calls to sigmoid_lowp and add_lowp appear literally in the trace; they themselves are not traced through.
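
One way to confirm this is to look at the graph's nodes directly: the composites appear as call_function nodes whose targets are the wrapped functions themselves, rather than their constituent float()/sigmoid()/half() calls:

for node in traced.graph.nodes:
    if node.op == 'call_function':
        print(node.name, '->', node.target.__name__)
# sigmoid_lowp -> sigmoid_lowp
# sigmoid_lowp_1 -> sigmoid_lowp
# add_lowp -> add_lowp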

Inlining calls#

Now let's define a function that allows these calls to be inlined during graph manipulation.

def inline_lowp_func(n : fx.Node):
    # If we find a call to a function from our "lowp" library, inline it
    if n.op == 'call_function' and n.target.__module__ == sigmoid_lowp.__module__:
        # We want to insert the operations comprising the implementation of the
        # function before the function itself. Then, we can swap the output value
        # of the function call with the output value for its implementation nodes
        tracer = fx.proxy.GraphAppendingTracer(n.graph)
        with n.graph.inserting_before(n):
            # We can inline code by using `fx.Proxy` instances.
            # map_arg traverses all aggregate types and applies the given function
            # to Node instances in the data structure. In this case, we are applying
            # the fx.Proxy constructor.
            proxy_args = torch.fx.node.map_arg(n.args, lambda x: torch.fx.Proxy(x, tracer))
            proxy_kwargs = torch.fx.node.map_arg(n.kwargs, lambda x: torch.fx.Proxy(x, tracer))
            # Call the function itself with proxy arguments. This will emit
            # nodes in the graph corresponding to the operations in the im-
            # plementation of the function
            output_proxy = n.target(*proxy_args, **proxy_kwargs)
            # Now replace the original node's uses with the output node of
            # the implementation.
            n.replace_all_uses_with(output_proxy.node)
            # Delete the old node
            n.graph.erase_node(n)
for node in traced.graph.nodes:
    if node.op == 'call_function' and node.target is sigmoid_lowp:
        inline_lowp_func(node)

# Don't forget to recompile after graph manipulation
traced.recompile()
print(traced.code)
def forward(self, x, y):
    float_1 = x.float();  x = None
    sigmoid = float_1.sigmoid();  float_1 = None
    half = sigmoid.half();  sigmoid = None
    float_2 = y.float();  y = None
    sigmoid_1 = float_2.sigmoid();  float_2 = None
    half_1 = sigmoid_1.half();  sigmoid_1 = None
    add_lowp = primitive_library_add_lowp(half, half_1);  half = half_1 = None
    return add_lowp
    

At this point, the implementation of sigmoid_lowp has been substituted in for all calls to that function.
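
If the fully lowered form is wanted, the same pass can be reused for the remaining add_lowp call, for example:

for node in traced.graph.nodes:
    if node.op == 'call_function' and node.target is add_lowp:
        inline_lowp_func(node)

traced.recompile()
print(traced.code)  # every lowp composite is now expressed as float()/sigmoid()/add/half() operations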

Inlining calls during tracing#

Now we will define a custom tracer that can selectively inline calls to certain composite operations on the fly.

f = Foo()

class InliningTracer(fx.Tracer):
    FNS_TO_INLINE = [add_lowp]

    def create_node(self, kind, target, args, kwargs, name=None, type_expr=None):
        if kind == 'call_function' and target in self.FNS_TO_INLINE:
            tracer = fx.proxy.GraphAppendingTracer(self.graph)
            # Trace through the implementation of the function rather than
            # create a node
            proxy_args = fx.node.map_arg(args, lambda x: torch.fx.Proxy(x, tracer))
            proxy_kwargs = fx.node.map_arg(kwargs, lambda x: torch.fx.Proxy(x, tracer))
            return target(*proxy_args, **proxy_kwargs).node
        else:
            return super().create_node(kind, target, args, kwargs, name, type_expr)


tracer = InliningTracer()
graph = tracer.trace(f)
module = torch.fx.GraphModule(f, graph)
print(module.code)
def forward(self, x, y):
    sigmoid_lowp = primitive_library_sigmoid_lowp(x);  x = None
    sigmoid_lowp_1 = primitive_library_sigmoid_lowp(y);  y = None
    float_1 = sigmoid_lowp.float();  sigmoid_lowp = None
    float_2 = sigmoid_lowp_1.float();  sigmoid_lowp_1 = None
    add = float_1 + float_2;  float_1 = float_2 = None
    half = add.half();  add = None
    return half
    

As you can see, the implementation of add_lowp was inlined in the course of tracing with our InliningTracer. Such functionality can be used, for example, to implement a backend that wants to see the lowered form of some operations but the high-level form of others.
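
If this pattern comes up often, the boilerplate can be folded into a small helper that mirrors fx.symbolic_trace; the name inline_trace below is just for illustration:

def inline_trace(root: torch.nn.Module) -> fx.GraphModule:
    # Trace `root`, inlining the composites listed in InliningTracer.FNS_TO_INLINE
    # while keeping the other wrapped functions as call_function nodes.
    tracer = InliningTracer()
    graph = tracer.trace(root)
    return fx.GraphModule(root, graph)

module = inline_trace(Foo())
print(module.code)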