调试#

通常在创作变换的过程中,我们的代码并不完全正确。在这种情况下,可能需要进行一些调试。关键是 backwards 工作:首先,检查调用生成的 module 的结果,以证明或否定正确性。然后,检查和调试生成的代码。然后,调试导致生成代码的变换过程。

变换创作中的常见陷阱#

不确定的 set 迭代顺序。在 Python 中,设置的数据类型是无序的。例如,使用 set 来包含节点等对象的集合可能会导致意外的不确定性。一个例子是迭代一组节点,将它们插入到图中。因为设置的数据类型是无序的,输出程序中运算的顺序将是不确定的,并且可以在程序调用之间更改。推荐的替代方法是使用 dict 数据类型,这是 Python 3.7(以及 cPython 3.6)开始按照插入顺序排序。通过将要重复数据删除的值存储在 dict 的键中,dict 可以等价地用于 set

检查 module 的正确性#

因为大多数深度学习 module 的输出都是由浮点 torch.Tensor 实例组成,检查两个 torch.nn.Module 结果之间的等价性不像做简单的相等性检查那样直接。为了激发这个想法,举个例子(RuntimeError:有多个值的张量的布尔值不明确):

import torch
import torch.fx
import torchvision.models as models


def transform(m : torch.nn.Module) -> torch.nn.Module:
    gm = torch.fx.symbolic_trace(m)

    # Imagine we're doing some transforms here
    # <...>

    gm.recompile()

    return gm

resnet18 = models.resnet18()
transformed_resnet18 = transform(resnet18)

input_image = torch.randn(5, 3, 224, 224)

assert resnet18(input_image) == transformed_resnet18(input_image)
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
/media/pc/data/4tb/lxw/home/lxw/hub/torch-book/doc/tutorial/fx/Debugging.ipynb Cell 2 in <cell line: 21>()
     <a href='vscode-notebook-cell://ssh-remote%2Bxin/media/pc/data/4tb/lxw/home/lxw/hub/torch-book/doc/tutorial/fx/Debugging.ipynb#W2sdnNjb2RlLXJlbW90ZQ%3D%3D?line=16'>17</a> transformed_resnet18 = transform(resnet18)
     <a href='vscode-notebook-cell://ssh-remote%2Bxin/media/pc/data/4tb/lxw/home/lxw/hub/torch-book/doc/tutorial/fx/Debugging.ipynb#W2sdnNjb2RlLXJlbW90ZQ%3D%3D?line=18'>19</a> input_image = torch.randn(5, 3, 224, 224)
---> <a href='vscode-notebook-cell://ssh-remote%2Bxin/media/pc/data/4tb/lxw/home/lxw/hub/torch-book/doc/tutorial/fx/Debugging.ipynb#W2sdnNjb2RlLXJlbW90ZQ%3D%3D?line=20'>21</a> assert resnet18(input_image) == transformed_resnet18(input_image)

RuntimeError: Boolean value of Tensor with more than one value is ambiguous

在这里,尝试用 == 运算符检查两个深度学习模型的值是否相等。然而,由于运算符返回的是张量而不是 bool 值的问题,而且由于浮点值的比较应该使用误差边界(或 epsilon)来解释浮点运算的非交换性,这两个问题都没有很好地定义。可以使用 torch.allclose(),它会考虑到相对和绝对公差阈值的近似比较:

assert torch.allclose(resnet18(input_image), transformed_resnet18(input_image))

与参考实现相比,这是工具箱中检查变换模块行为是否如期望的那样的第一个工具。

调试生成的代码#

因为 FX 在 torch.fx.GraphModule 上生成 forward() 函数,所以使用传统的调试技术(如 print 语句或 pdb)就不那么直接了。幸运的是,有几种技术可以用来调试生成的代码。

使用 pdb#

调用 pdb 进入正在运行的程序。尽管表示 torch.fx.Graph 的代码不在任何源文件中,但是当调用 forward 传递时,仍然可以使用 pdb 手动进入它。

import torch
from torch import fx
import torchvision.models as models

def my_pass(inp: torch.nn.Module, tracer_class : type = fx.Tracer) -> torch.nn.Module:
    graph = tracer_class().trace(inp)
    # Transformation logic here
    # <...>

    # Return new Module
    return fx.GraphModule(inp, graph)

my_module = models.resnet18()
my_module_transformed = my_pass(my_module)

input_value = torch.randn(5, 3, 224, 224)

# When this line is executed at runtime, we will be dropped into an
# interactive `pdb` prompt. We can use the `step` or `s` command to
# step into the execution of the next line
import pdb; pdb.set_trace()

my_module_transformed(input_value)
--Return--
None
> /tmp/ipykernel_2297333/4158250709.py(21)<cell line: 21>()
     19 # interactive `pdb` prompt. We can use the `step` or `s` command to
     20 # step into the execution of the next line
---> 21 import pdb; pdb.set_trace()
     22 
     23 my_module_transformed(input_value)

打印生成代码#

如果您想要多次运行相同的代码,那么使用 pdb 逐步找到正确的代码可能有点乏味。在这种情况下,一种方法是简单地将生成的 forward 传递复制粘贴到代码中,并从那里检查它。

# Assume that `traced` is a GraphModule that has undergone some
# number of transforms

# Copy this code for later
print(traced)
# Print the code generated from symbolic tracing. This outputs:
"""
def forward(self, y):
    x = self.x
    add_1 = x + y;  x = y = None
    return add_1
"""

# Subclass the original Module
class SubclassM(M):
    def __init__(self):
        super().__init__()

    # Paste the generated `forward` function (the one we printed and
    # copied above) here
    def forward(self, y):
        x = self.x
        add_1 = x + y;  x = y = None
        return add_1

# Create an instance of the original, untraced Module. Then, create an
# instance of the Module with the copied `forward` function. We can
# now compare the output of both the original and the traced version.
pre_trace = M()
post_trace = SubclassM()

使用 to_folder() 函数#

to_folder()GraphModule 中的方法,它允许你将生成的 FX 代码转储到文件夹中。尽管像打印生成的代码那样,将 forward 传递复制到代码中通常就足够了,但是使用 to_folder() 检查模块和参数可能更容易。

m = symbolic_trace(M())
m.to_folder("foo", "Bar")
from foo import Bar
y = Bar()

在运行上面的示例之后,可以查看 foo/module.py 中的代码,并根据需要修改它(例如添加 print 语句或使用 pdb),以调试生成的代码。

调试变换#

既然已经确定了变换正在创建不正确的代码,现在是调试变换本身的时候了。

# Sample Module
class M(torch.nn.Module):
    def forward(self, x, y):
        return x + y

# Create an instance of `M`
m = M()

# Symbolically trace an instance of `M` (returns a GraphModule). In
# this example, we'll only be discussing how to inspect a
# GraphModule, so we aren't showing any sample transforms for the
# sake of brevity.
traced = symbolic_trace(m)

# Print the code produced by tracing the module.
print(traced)
# The generated `forward` function is:
"""
def forward(self, x, y):
    add = x + y;  x = y = None
    return add
"""

# Print the internal Graph.
print(traced.graph)
# This print-out returns:
"""
graph():
    %x : [#users=1] = placeholder[target=x]
    %y : [#users=1] = placeholder[target=y]
    %add : [#users=1] = call_function[target=operator.add](args = (%x, %y), kwargs = {})
    return add
"""

# Print a tabular representation of the internal Graph.
traced.graph.print_tabular()
# This gives us:
"""
opcode         name    target                   args    kwargs
-------------  ------  -----------------------  ------  --------
placeholder    x       x                        ()      {}
placeholder    y       y                        ()      {}
call_function  add     <built-in function add>  (x, y)  {}
output         output  output                   (add,)  {}
"""

使用上面的实用函数,可以在应用变换之前和之后比较跟踪的 torch.nn.Module

抛开上面的例子,考虑下面的代码:

# Sample user-defined function
def transform_graph(module: torch.nn.Module, tracer_class : type = fx.Tracer) -> torch.nn.Module:
    # Get the Graph from our traced Module
    g = tracer_class().trace(module)

    """
    Transformations on `g` go here
    """

    return fx.GraphModule(module, g)

# Transform the Graph
transformed = transform_graph(traced)

# Print the new code after our transforms. Check to see if it was
# what we expected
print(transformed)

使用上面的例子,假设对 print(tracing) 的调用告诉我们变换中有一个错误。希望使用调试器找到哪里出了问题。可以通过中断 `transform_graph(已跟踪),然后按s“进入”对transform_graph(已跟踪)的调用来查看转换过程中发生了什么。