__torch_dispatch__ 是什么#
参考:what-and-why-is-torch-dispatch & lets-talk-about-the-pytorch-dispatcher
简而言之:__torch_dispatch__ 允许你利用 dispatcher 的强大功能任意扩展 PyTorch,但现在是从 Python 中实现。这有望为 PyTorch 带来全新的灵活性,全部都在 Python 中实现。
PyTorch 的核心是什么?(剧透:dispatcher)#
从宏观上看,PyTorch 做了两件事。
根据输入,确定要运行的合适内核,以及是 CUDA 实现还是 CPU 实现。
根据输入,在自动求导图中注册合适的事物。
这两件事使得 PyTorch 从“numpy”变成了“支持 CUDA 和自动求导的 numpy”。最关键的是,这两件事都是通过调度器(dispatcher)在 PyTorch 中实现的。
核心来说,调度器是一个系统,根据输入的属性决定调用哪个函数。要了解更多,建议阅读 Edward Yang 的这篇优秀文章。
例如,假设有类似 aten::sin(Tensor) 的东西。实际上发生了什么?首先,会检查 Tensor 是否需要 grad。如果是,调用 aten::sin_with_backward (这不是真实的算子,但本质上这构建了反向传播)。然后,如果 Tensor 在 CUDA 上,调度到 aten::sin_with_backward_cuda 。除了自动求导,自动混合精度或 vmap 等功能也是通过调度器实现的。
本质上,调度器负责 PyTorch 提供的核心功能。由于其在 PyTorch 中的核心地位,它允许以其他方法无法提供的深度集成到框架中。因此,调度器也是扩展 PyTorch 功能的核心位置之一。例如,Functorch 的 vmap 可以透明地与 PyTorch 中的几乎所有功能(包括自动求导)无缝集成。为什么?因为它存在于调度器中。
为什么要有 dispatcher 系统?#
如果你想想看,PyTorch 做的事情其实相当令人惊讶。你可以用普通的 Python 代码,然后仅通过在输入上设置 requires_grad ,它就能做完全不同的事情——计算梯度!
PyTorch 的 dispatcher 系统是底层动态 dispatcher 系统的示例实现。例如,可以考虑设备 dispatcher 的一种实现方式
def sin(x: Tensor):
if x.device == 'cuda':
return sin_cuda(x)
else:
return sin_cpu(x)
除了相当丑陋之外,这种方式还引发了组合性问题——无法在不修改实际函数实现的情况下扩展 sin !例如,假设要添加 vmap。是否需要在函数内部添加另一个条件语句?
因此,允许 dispatcher 根据输入的属性来决定调度哪个 sin 的实现。现在有了类似的东西。
def sin(x: Tensor[requires_grad=False]): return sin_without_grad(x)
def sin(x: Tensor[requires_grad=True]): return sin_with_grad(x)
def sin(x: Tensor[is_batched=True]): return sin_batched(x)
更好的是,在许多情况下,实际上可以重用其他实现。例如, sin_with_grad 可能仍然会在某个地方调用 sin_without_grad 。例如,也许它看起来像这样:
def sin(x: Tensor[requires_grad=True]):
no_grad_x = x.requires_grad(False)
out: Tensor[requires_grad=True] = sin(no_grad_x: Tensor[requires_grad=False])
out.register_backwards_function(sin)
return out.requires_grad(True)
顺便提一下,将这种特殊行为视为包装子类而不是张量的属性可能更合理。因此,上述示例可能看起来像这样:
def sin(x: Tensor) # Base tensor, just calls sin
def sin(x: GradTensor(Tensor)): # Wrapper gradient tensor that tracks graadients
def sin(x: BatchedTensor(Tensor)): # Wrapper batched tensor that performs vmap
事实上,许多其他功能也可以用这种方式实现!例如日志记录、跟踪、FLOP 计数、vmap、对角张量、掩码张量等!例如(仅伪代码)
FLOP 计数
flop_count = 0
def sin(x: FlopTensor(Tensor)):
unwrap_x: Tensor = x.elem # Unwraps FlopTensor to get the underlying Tensor
flop_count += get_sin_flops(x.shape) # Counts flops
out = sin(unwrap_x) # Calls sin on the unwrapped tensor (i.e. redispatches)
return FlopTensor(out)
Tracer
def ProxyTensor(Tensor):
elem: Tensor
proxy: Proxy
def sin(x: ProxyTensor(Tensor)):
proxy = x.proxy
unwrap_x = x.elem
out = sin(unwrap_x)
proxy_out = proxy.call_function('sin')
return ProxyTensor(out, proxy_out)
基本上,dispatcher 允许你做各种各样的事情,并以可组合的方式覆盖各种 PyTorch 行为。但……它带来了很多限制。首先,向分发器注册新功能……需要与 PyTorch 核心团队沟通。但更重要的是,注册这些功能需要在 C++ 中完成!
因此,作为高层次的目标, __torch_dispatch__ 允许你从 Python 中利用 dispatcher 的所有功能!
__torch_dispatch__ 为什么重要?#
看一下在 PyTorch 中调用算子的典型流程,以及在哪些地方可以修改行为。
这是一张 __torch_dispatch__ 与 vmap 工作方式的示意图。实线箭头表示实际走过的路径,虚线箭头表示根据分发键的不同,可能走过的路径。

请注意:
__torch_dispatch__位于 vmap 行为之后(因此可以捕获它),__torch_dispatch__是唯一从 C++ 返回到 Python 的途径。
简而言之,目前 PyTorch 中的几乎所有扩展点(少数例外情况除外)都在步骤 1 之前完成。这意味着在某种程度上,这些功能都无法了解其背后的机制!这限制了许多潜在的功能。
以计算 FLOP 为例,PyTorch 早期的 FLOP 计数器都是在框架之上实现的,通常是在模块级别实现的。这种方式在一定程度上可行,但一旦用户使用非标准模块或在模块内部进行操作,就会出现问题。后来,人们开始在 PyTorch 框架内部实现计数器,但仍然在 C++ 之上(即 __torch_function__ 和 FX),这使得他们能够捕获模块内的算子。但是……这些方法从未能够捕获反向传播过程,也无法捕获雅可比矩阵或海森矩阵的 FLOP 计数。
只有通过在 C++调度器中与 __torch_dispatch__ 集成,才能创建能够捕获反向 FLOPs 的 FLOP 计数器。
基本上, __torch_function__ 只允许你在 Python 中进行修改,但如果你想控制 PyTorch 中发生的一切?你需要使用 __torch_dispatch__。
__torch_dispatch__ 长啥样?#
看简单的例子,比如说,你想将每个 aten::add 替换为 aten::sub 。
class FooTensor(torch.Tensor):
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
# First, we must unwrap the wrapper tensors to get the inner tensor object
def unwrap(x):
return x.elem if isinstance(x, FooTensor) else x
args = tree_map(unwrap, args)
kwargs = tree_map(unwrap, kwargs)
# Now, we check the function to determine how to handle it. If it's
# aten.add, then we call aten.sub. Otherwise, we pass through to
# the original function
if func == torch.ops.aten.add:
out = torch.ops.aten.sub(*args, **kwargs)
else:
out = func(*args, **kwargs)
# Now, we want to continue propagating this tensor, so we rewrap Tensors in
# our custom tensor subclass
def wrap(x):
return FooTensor(x) if isinstance(x, Tensor) else x
return tree_map(wrap, out)
如你所见, __torch_dispatch__ 提供了极大的灵活性。对于每个 ATen 算子,都可以对其进行任意处理,包括:
在算子之前执行一些操作(包括记录日志或实际修改值)
在算子之后执行操作(同上)。
调用任意实现的函数(例如调用 NumPy 或另一个编译器)。
重新调用默认实现。
请注意,3 - 能够调用任意实现的函数,这在实际张量表示方面带来了极大的灵活性。例如,可以将张量表示为 Int8 量化张量,反量化张量,然后调用原始函数。
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
def unwrap(e):
if isinstance(e, QuantTensor):
return cls.dequantize(e.mat, e.row_factor, e.column_factor, e.requires_grad, e.dtype)
else:
return e
out = func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs))
为了使示例更清晰,略去了许多细节。如果你想实际尝试 __torch_dispatch__ 或了解更多细节,请参阅 albanD/subclass_zoo。
长期愿景是什么?#
以用 PyTorch 编写的 ResNet18 模型为例。这个程序是什么?一种看待它的方法是,它只是高级的汇编代码表示。但……它不仅仅对应于单一的一系列指令。根据输入的不同,它可能在 CPU、GPU 或 TPU 上运行。根据是否需要 grad,它可能在反向传播时保存激活值,也可能不保存。根据是否启用了自动混合精度,它可能自动在 float32 和 float16 之间转换张量,也可能不转换。
或许更好的一种看待方式是将其视为模型的抽象表示。就像数学公式可以被翻译成代码一样,PyTorch 将这个模型的抽象表示翻译成数以亿计的实际执行代码。但不仅可以做到上述示例中的这些事情。
保持模型代码不变的情况下,使用 __torch_dispatch__ ,用户应该能够
将权重表示为某种任意更高效的表示形式
低秩逼近(Low rank approximation)
8-bit 量化格式
对角张量(Diagonal Tensors)
线性算子张量(Linear Operator Tensors)
任意 Einsum 张量
以 MaskedTensors 作为输入
以某种任意的惰性执行方式执行(如 LazyTensor)
跟踪计算图中发生的算子(即 AOTAutograd)
…
而且,这些操作应该能够组合。它们应该能够计算出以对角矩阵表示掩码的 MaskedTensors 的逐样本梯度,利用张量/模型/数据并行性进行并行计算,然后对整个算子进行追踪,以便将其传递给编译器。
这些事情以前并非不可能做到,只是需要在 PyTorch 核心部分投入大量资源。 __torch_dispatch__ 只是将这一点开放给了更多人。
附注:关于追踪的部分,这可能会非常重要。想要做的所有这些张量子类都是对底层内核算子的抽象。但是,通过追踪,可以穿透这些抽象层次,直接访问底层的张量算子。
例如,可以(假设性地)添加自定义的“4 位蝴蝶稀疏张量”,只要它们的所有底层算子都是 ATen 算子,就可以使用这个张量进行训练/评估(全部在 Python 中完成),然后导出张量语义,用于移动设备!需要注意的是,这并不是一种假设,目前就可以做到 :slight_smile: