符号追踪的局限性#

FX 使用符号跟踪系统(又称 符号执行)以可变换/可分析的形式捕获程序的语义。

系统是追踪的(tracing),因为它执行程序(实际上是 Module 或函数)来记录运算。它是符号的(symbolic),因为在执行过程中流经程序的数据不是真正的数据,而是符号(FX 术语中的 Proxy)。

尽管符号追踪适用于大多数神经网络代码,但它也有一些局限性。

动态流程控制#

符号追踪的主要限制是它目前不支持 动态控制流(dynamic control flow)。也就是说,循环或 if 语句的条件可能取决于程序的输入值。

比如:

import torch
from torch import fx

def func_to_trace(x):
    if x.sum() > 0:
        return torch.relu(x)
    else:
        return torch.neg(x)

traced = fx.symbolic_trace(func_to_trace)


"""
  <...>
  File "dyn.py", line 6, in func_to_trace
    if x.sum() > 0:
  File "pytorch/torch/fx/proxy.py", line 155, in __bool__
    return self.tracer.to_bool(self)
  File "pytorch/torch/fx/proxy.py", line 85, in to_bool
    raise TraceError('symbolically traced variables cannot be used as inputs to control flow')
torch.fx.proxy.TraceError: symbolically traced variables cannot be used as inputs to control flow
"""

if 语句的条件依赖于 x.sum() 的值,而 x.sum() 依赖于函数输入 x 的值。因为 x 可以改变(例如,如果你将新的输入张量传递给追踪函数),这就是 动态控制流。回溯遍历代码,向您显示这种情况发生的位置。

静态流程控制#

另一方面,支持所谓的 静态控制流。静态控制流是循环或 if 语句,其值不能在调用之间更改。通常,在 PyTorch 程序中,这种控制流用于基于超参数对模型的体系结构做出决策的代码。举个具体的例子:

import torch
from torch import fx

class MyModule(torch.nn.Module):
    def __init__(self, do_activation : bool = False):
        super().__init__()
        self.do_activation = do_activation
        self.linear = torch.nn.Linear(512, 512)

    def forward(self, x):
        x = self.linear(x)
        # 这个 if 语句就是所谓的静态控制流。
        # 它的条件不依赖于任何输入值
        if self.do_activation:
            x = torch.relu(x)
        return x

without_activation = MyModule(do_activation=False)
with_activation = MyModule(do_activation=True)
traced_without_activation = fx.symbolic_trace(without_activation)
print(traced_without_activation.code)
def forward(self, x):
    linear = self.linear(x);  x = None
    return linear
    
traced_with_activation = fx.symbolic_trace(with_activation)
print(traced_with_activation.code)
def forward(self, x):
    linear = self.linear(x);  x = None
    relu = torch.relu(linear);  linear = None
    return relu
    

if-语句 if self.do_activation 不依赖于任何函数输入,因此它是静态的。do_activation 可以被认为是超参数,具有该参数不同值的 MyModule 的不同实例的追踪具有不同的代码。这是符号跟踪支持的有效模式。

许多动态控制流的实例在语义上是静态控制流。这些实例可以通过移除对输入值的数据依赖来支持符号跟踪,例如将值移动到 Module 属性,或者在符号跟踪期间将具体值绑定到参数:

def f(x, flag):
    if flag: return x
    else: return x*2

fx.symbolic_trace(f) # Fails!

fx.symbolic_trace(f, concrete_args={'flag': True})

torch 函数#

FX 使用 __torch_function__ 作为拦截调用的机制(有关这方面的更多信息,请参阅技术概述)。一些函数,例如 Python 内置函数或数学模块中的函数,没有被 __torch_function__ 覆盖,但仍然希望在符号跟踪中捕获它们。例如:

import torch
import torch.fx
from math import sqrt

def normalize(x):
    """
    Normalize `x` by the size of the batch dimension
    """
    return x / sqrt(len(x))

# It's valid Python code
normalize(torch.rand(3, 4))

traced = torch.fx.symbolic_trace(normalize)
"""
  <...>
  File "sqrt.py", line 9, in normalize
    return x / sqrt(len(x))
  File "pytorch/torch/fx/proxy.py", line 161, in __len__
    raise RuntimeError("'len' is not supported in symbolic tracing by default. If you want "
RuntimeError: 'len' is not supported in symbolic tracing by default. If you want this call to be recorded, please call torch.fx.wrap('len') at module scope
"""

这个错误告诉我们不支持内置函数 len()。可以使用 wrap() API 将这样的函数作为直接调用记录在跟踪中:

fx.wrap('len')
fx.wrap('sqrt')

traced = fx.symbolic_trace(normalize)

print(traced.code)
"""
import math
def forward(self, x):
    len_1 = len(x)
    sqrt_1 = math.sqrt(len_1);  len_1 = None
    truediv = x / sqrt_1;  x = sqrt_1 = None
    return truediv
"""

使用 Tracer 自定义追踪#

Tracer 类是 symbolic_trace() 实现的基础类。跟踪的行为可以通过子类化 Tracer 来定制,如下所示:

class MyCustomTracer(torch.fx.Tracer):
    """自定义追踪器"""
    ...


# 使用自定义跟踪程序来跟踪整个 module
class MyModule(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x) + torch.ones(3, 4)

mod = MyModule()

# trace() 返回 Graph
traced_graph = MyCustomTracer().trace(mod)
# 包装到 GraphModule 中,使其可运行
traced = fx.GraphModule(mod, traced_graph)

叶模块#

叶模块(Leaf Module)是在符号跟踪中作为调用而不是被跟踪的模块。叶模块的默认集合是标准 torch.nn 模块实例。例如:

class MySpecialSubmodule(torch.nn.Module):
    def forward(self, x):
        return torch.neg(x)

class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(3, 4)
        self.submod = MySpecialSubmodule()

    def forward(self, x):
        return self.submod(self.linear(x))

traced = torch.fx.symbolic_trace(MyModule())
print(traced.code)
def forward(self, x):
    linear = self.linear(x);  x = None
    neg = torch.neg(linear);  linear = None
    return neg
    

linear 被保留为调用,但是 submod 被跟踪。这是因为默认的“叶模块”包含了所有标准的 torch.nn 的模块。

叶模块集可以通过重写 is_leaf_module() 来定制。

Miscellanea#

Tensor 构造函数(torch.zeros(), torch.ones(), torch.rand(), torch.randn(), torch.sparse_coo_tensor())目前不可追踪。

  • 可以使用确定性构造函数(zeros, ones),它们产生的值将作为常量嵌入到跟踪中。只有当这些构造函数的参数引用动态输入大小时,才会出现问题。在这种情况下, ones_like()zeros_like() 可能是可行的替代方法。

  • 非确定性构造函数(rand(), randn())将在跟踪中嵌入单个随机值。这可能不是预期的行为。解决办法是 使用 torch.fx.wrap() 包装。

    @torch.fx.wrap
    def torch_randn(x, shape):
        return torch.randn(shape)
    
    def f(x):
        return x + torch_randn(x, 5)
    fx.symbolic_trace(f)
    
  • 类型注解

    • Python 3 风格的类型注解(例如 func(x : torch.Tensor, y : int) -> torch.Tensor) 是受支持的,并将通过符号跟踪保存。

    • 目前不支持函数中局部名称的注解。

  • training flag 和子模块周围有问题

    • 当使用像 torch.nn.functional.dropout() 这样的函数时,训练参数通常被传递为 self.training。在 FX 跟踪过程中,这可能会作为常数值进行处理。

    
    
import torch
import torch.fx

class DropoutRepro(torch.nn.Module):
  def forward(self, x):
    return torch.nn.functional.dropout(x, training=self.training)


traced = torch.fx.symbolic_trace(DropoutRepro())
print(traced.code)
def forward(self, x):
    dropout = torch.nn.functional.dropout(x, p = 0.5, training = True, inplace = False);  x = None
    return dropout
    
traced.eval()

x = torch.randn(5, 3)
torch.testing.assert_allclose(traced(x), x)
---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
/media/pc/data/4tb/lxw/home/lxw/hub/torch-book/doc/tutorial/fx/LimitationsSymbolicTracing.ipynb Cell 18 in <cell line: 4>()
      <a href='vscode-notebook-cell://ssh-remote%2Bxin/media/pc/data/4tb/lxw/home/lxw/hub/torch-book/doc/tutorial/fx/LimitationsSymbolicTracing.ipynb#X42sdnNjb2RlLXJlbW90ZQ%3D%3D?line=0'>1</a> traced.eval()
      <a href='vscode-notebook-cell://ssh-remote%2Bxin/media/pc/data/4tb/lxw/home/lxw/hub/torch-book/doc/tutorial/fx/LimitationsSymbolicTracing.ipynb#X42sdnNjb2RlLXJlbW90ZQ%3D%3D?line=2'>3</a> x = torch.randn(5, 3)
----> <a href='vscode-notebook-cell://ssh-remote%2Bxin/media/pc/data/4tb/lxw/home/lxw/hub/torch-book/doc/tutorial/fx/LimitationsSymbolicTracing.ipynb#X42sdnNjb2RlLXJlbW90ZQ%3D%3D?line=3'>4</a> torch.testing.assert_allclose(traced(x), x)

File /media/pc/data/4tb/lxw/libs/anaconda3/envs/tvmx/lib/python3.10/site-packages/torch/testing/_deprecated.py:32, in warn_deprecated.<locals>.outer_wrapper.<locals>.inner_wrapper(*args, **kwargs)
     30 @functools.wraps(fn)
     31 def inner_wrapper(*args: Any, **kwargs: Any) -> Any:
---> 32     return_value = fn(*args, **kwargs)
     33     tail = instructions(name, args, kwargs, return_value) if callable(instructions) else instructions
     34     msg = (head + tail).strip()

File /media/pc/data/4tb/lxw/libs/anaconda3/envs/tvmx/lib/python3.10/site-packages/torch/testing/_deprecated.py:80, in assert_allclose(actual, expected, rtol, atol, equal_nan, msg)
     77 if rtol is None and atol is None:
     78     rtol, atol = _get_default_rtol_and_atol(actual, expected)
---> 80 torch.testing.assert_close(
     81     actual,
     82     expected,
     83     rtol=rtol,
     84     atol=atol,
     85     equal_nan=equal_nan,
     86     check_device=True,
     87     check_dtype=False,
     88     check_stride=False,
     89     msg=msg or None,
     90 )

    [... skipping hidden 1 frame]

File /media/pc/data/4tb/lxw/libs/anaconda3/envs/tvmx/lib/python3.10/site-packages/torch/testing/_comparison.py:1095, in assert_equal(actual, expected, pair_types, sequence_types, mapping_types, msg, **options)
   1092     return
   1094 # TODO: compose all metas into one AssertionError
-> 1095 raise error_metas[0].to_error(msg)

AssertionError: Tensor-likes are not close!

Mismatched elements: 15 / 15 (100.0%)
Greatest absolute difference: 1.709273338317871 at index (4, 2) (up to 1e-05 allowed)
Greatest relative difference: 1.0 at index (0, 0) (up to 0.0001 allowed)

但是,当使用标准的 Dropout 子模块时,training 标志将被封装(因为保留了 Module 对象模型)且可以更改。

class DropoutRepro2(torch.nn.Module):
  def __init__(self):
    super().__init__()
    self.drop = torch.nn.Dropout()

  def forward(self, x):
    return self.drop(x)

traced = torch.fx.symbolic_trace(DropoutRepro2())
print(traced.code)
traced.eval()

x = torch.randn(5, 3)
torch.testing.assert_close(traced(x), x)
def forward(self, x):
    drop = self.drop(x);  x = None
    return drop
    
  • 由于这种差异,可以考虑将与 training 标志动态交互的模块标记为叶模块。