__torch_dispatch__ 是什么#

参考:what-and-why-is-torch-dispatch & lets-talk-about-the-pytorch-dispatcher

简而言之:__torch_dispatch__ 允许你利用 dispatcher 的强大功能任意扩展 PyTorch,但现在是从 Python 中实现。这有望为 PyTorch 带来全新的灵活性,全部都在 Python 中实现。

PyTorch 的核心是什么?(剧透:dispatcher)#

从宏观上看,PyTorch 做了两件事。

  • 根据输入,确定要运行的合适内核,以及是 CUDA 实现还是 CPU 实现。

  • 根据输入,在自动求导图中注册合适的事物。

这两件事使得 PyTorch 从“numpy”变成了“支持 CUDA 和自动求导的 numpy”。最关键的是,这两件事都是通过调度器(dispatcher)在 PyTorch 中实现的。

核心来说,调度器是一个系统,根据输入的属性决定调用哪个函数。要了解更多,建议阅读 Edward Yang 的这篇优秀文章

例如,假设有类似 aten::sin(Tensor) 的东西。实际上发生了什么?首先,会检查 Tensor 是否需要 grad。如果是,调用 aten::sin_with_backward (这不是真实的算子,但本质上这构建了反向传播)。然后,如果 Tensor 在 CUDA 上,调度到 aten::sin_with_backward_cuda 。除了自动求导,自动混合精度或 vmap 等功能也是通过调度器实现的。

本质上,调度器负责 PyTorch 提供的核心功能。由于其在 PyTorch 中的核心地位,它允许以其他方法无法提供的深度集成到框架中。因此,调度器也是扩展 PyTorch 功能的核心位置之一。例如,Functorch 的 vmap 可以透明地与 PyTorch 中的几乎所有功能(包括自动求导)无缝集成。为什么?因为它存在于调度器中。

为什么要有 dispatcher 系统?#

如果你想想看,PyTorch 做的事情其实相当令人惊讶。你可以用普通的 Python 代码,然后仅通过在输入上设置 requires_grad ,它就能做完全不同的事情——计算梯度!

PyTorch 的 dispatcher 系统是底层动态 dispatcher 系统的示例实现。例如,可以考虑设备 dispatcher 的一种实现方式

def sin(x: Tensor):
    if x.device == 'cuda':
        return sin_cuda(x)
    else:
        return sin_cpu(x)

除了相当丑陋之外,这种方式还引发了组合性问题——无法在不修改实际函数实现的情况下扩展 sin !例如,假设要添加 vmap。是否需要在函数内部添加另一个条件语句?

因此,允许 dispatcher 根据输入的属性来决定调度哪个 sin 的实现。现在有了类似的东西。

def sin(x: Tensor[requires_grad=False]): return sin_without_grad(x)
def sin(x: Tensor[requires_grad=True]): return sin_with_grad(x)
def sin(x: Tensor[is_batched=True]): return sin_batched(x)

更好的是,在许多情况下,实际上可以重用其他实现。例如, sin_with_grad 可能仍然会在某个地方调用 sin_without_grad 。例如,也许它看起来像这样:

def sin(x: Tensor[requires_grad=True]):
    no_grad_x = x.requires_grad(False)
    out: Tensor[requires_grad=True] = sin(no_grad_x: Tensor[requires_grad=False])
    out.register_backwards_function(sin)
    return out.requires_grad(True)

顺便提一下,将这种特殊行为视为包装子类而不是张量的属性可能更合理。因此,上述示例可能看起来像这样:

def sin(x: Tensor) # Base tensor, just calls sin
def sin(x: GradTensor(Tensor)): # Wrapper gradient tensor that tracks graadients
def sin(x: BatchedTensor(Tensor)): # Wrapper batched tensor that performs vmap

事实上,许多其他功能也可以用这种方式实现!例如日志记录、跟踪、FLOP 计数、vmap、对角张量、掩码张量等!例如(仅伪代码)

FLOP 计数

flop_count = 0
def sin(x: FlopTensor(Tensor)):
    unwrap_x: Tensor = x.elem  # Unwraps FlopTensor to get the underlying Tensor
    flop_count += get_sin_flops(x.shape)  # Counts flops
    out = sin(unwrap_x)  # Calls sin on the unwrapped tensor (i.e. redispatches)
    return FlopTensor(out)

Tracer

def ProxyTensor(Tensor):
    elem: Tensor
    proxy: Proxy
    
def sin(x: ProxyTensor(Tensor)):
   proxy = x.proxy
   unwrap_x = x.elem
   out = sin(unwrap_x)
   proxy_out = proxy.call_function('sin')   
   return ProxyTensor(out, proxy_out)

基本上,dispatcher 允许你做各种各样的事情,并以可组合的方式覆盖各种 PyTorch 行为。但……它带来了很多限制。首先,向分发器注册新功能……需要与 PyTorch 核心团队沟通。但更重要的是,注册这些功能需要在 C++ 中完成!

因此,作为高层次的目标, __torch_dispatch__ 允许你从 Python 中利用 dispatcher 的所有功能!

__torch_dispatch__ 为什么重要?#

看一下在 PyTorch 中调用算子的典型流程,以及在哪些地方可以修改行为。

这是一张 __torch_dispatch__ 与 vmap 工作方式的示意图。实线箭头表示实际走过的路径,虚线箭头表示根据分发键的不同,可能走过的路径。

请注意:

  1. __torch_dispatch__ 位于 vmap 行为之后(因此可以捕获它),

  2. __torch_dispatch__ 是唯一从 C++ 返回到 Python 的途径。

简而言之,目前 PyTorch 中的几乎所有扩展点(少数例外情况除外)都在步骤 1 之前完成。这意味着在某种程度上,这些功能都无法了解其背后的机制!这限制了许多潜在的功能。

以计算 FLOP 为例,PyTorch 早期的 FLOP 计数器都是在框架之上实现的,通常是在模块级别实现的。这种方式在一定程度上可行,但一旦用户使用非标准模块或在模块内部进行操作,就会出现问题。后来,人们开始在 PyTorch 框架内部实现计数器,但仍然在 C++ 之上(即 __torch_function__ 和 FX),这使得他们能够捕获模块内的算子。但是……这些方法从未能够捕获反向传播过程,也无法捕获雅可比矩阵或海森矩阵的 FLOP 计数。

只有通过在 C++调度器中与 __torch_dispatch__ 集成,才能创建能够捕获反向 FLOPs 的 FLOP 计数器。

基本上, __torch_function__ 只允许你在 Python 中进行修改,但如果你想控制 PyTorch 中发生的一切?你需要使用 __torch_dispatch__

__torch_dispatch__ 长啥样?#

看简单的例子,比如说,你想将每个 aten::add 替换为 aten::sub

class FooTensor(torch.Tensor):
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        # First, we must unwrap the wrapper tensors to get the inner tensor object
        def unwrap(x):
                return x.elem if isinstance(x, FooTensor) else x
                
        args = tree_map(unwrap, args)
        kwargs = tree_map(unwrap, kwargs)
        # Now, we check the function to determine how to handle it. If it's 
        # aten.add, then we call aten.sub. Otherwise, we pass through to 
        # the original function
        if func == torch.ops.aten.add:
            out = torch.ops.aten.sub(*args, **kwargs)
        else:
            out = func(*args, **kwargs)
        
        # Now, we want to continue propagating this tensor, so we rewrap Tensors in
        # our custom tensor subclass
        def wrap(x):
            return FooTensor(x) if isinstance(x, Tensor) else x
            
        return tree_map(wrap, out)

如你所见, __torch_dispatch__ 提供了极大的灵活性。对于每个 ATen 算子,都可以对其进行任意处理,包括:

  • 在算子之前执行一些操作(包括记录日志或实际修改值)

  • 在算子之后执行操作(同上)。

  • 调用任意实现的函数(例如调用 NumPy 或另一个编译器)。

  • 重新调用默认实现。

请注意,3 - 能够调用任意实现的函数,这在实际张量表示方面带来了极大的灵活性。例如,可以将张量表示为 Int8 量化张量,反量化张量,然后调用原始函数。

def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
    def unwrap(e):
        if isinstance(e, QuantTensor):
            return cls.dequantize(e.mat, e.row_factor, e.column_factor, e.requires_grad, e.dtype)
        else:
            return e
    out = func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs))

为了使示例更清晰,略去了许多细节。如果你想实际尝试 __torch_dispatch__ 或了解更多细节,请参阅 albanD/subclass_zoo

长期愿景是什么?#

以用 PyTorch 编写的 ResNet18 模型为例。这个程序是什么?一种看待它的方法是,它只是高级的汇编代码表示。但……它不仅仅对应于单一的一系列指令。根据输入的不同,它可能在 CPU、GPU 或 TPU 上运行。根据是否需要 grad,它可能在反向传播时保存激活值,也可能不保存。根据是否启用了自动混合精度,它可能自动在 float32float16 之间转换张量,也可能不转换。

或许更好的一种看待方式是将其视为模型的抽象表示。就像数学公式可以被翻译成代码一样,PyTorch 将这个模型的抽象表示翻译成数以亿计的实际执行代码。但不仅可以做到上述示例中的这些事情。

保持模型代码不变的情况下,使用 __torch_dispatch__ ,用户应该能够

而且,这些操作应该能够组合。它们应该能够计算出以对角矩阵表示掩码的 MaskedTensors 的逐样本梯度,利用张量/模型/数据并行性进行并行计算,然后对整个算子进行追踪,以便将其传递给编译器。

这些事情以前并非不可能做到,只是需要在 PyTorch 核心部分投入大量资源。 __torch_dispatch__ 只是将这一点开放给了更多人。

附注:关于追踪的部分,这可能会非常重要。想要做的所有这些张量子类都是对底层内核算子的抽象。但是,通过追踪,可以穿透这些抽象层次,直接访问底层的张量算子。

例如,可以(假设性地)添加自定义的“4 位蝴蝶稀疏张量”,只要它们的所有底层算子都是 ATen 算子,就可以使用这个张量进行训练/评估(全部在 Python 中完成),然后导出张量语义,用于移动设备!需要注意的是,这并不是一种假设,目前就可以做到 :slight_smile: