扩展 torch 原生 API#
虽然 __torch_function__ 允许人们有效地扩展 PyTorch 的纯 Python 组件的行为,但它不允许扩展用 C++ 实现的 PyTorch 部分。为此,Tensor 子类也可以定义 __torch_dispatch__ ,它将能够在 C++ 层面覆盖行为。
要有效使用此功能,了解 PyTorch 原生部分的实现方式至关重要。其中最重要的组件是称之为“dispatcher”的部分(最佳描述可在这篇博客文章中找到,尽管内容稍有过时)。顾名思义,它负责为特定函数调用选择正确的后端函数。例如,当调用 torch.add(a, b) 时,dispatcher 会检查两个参数,确定该特定调用应使用哪些“功能”(自动微分、自动类型转换、函数化等)和哪些“后端”(CPU、CUDA、MPS等),最终调用所有正确的内核。内核常做的一件事是“redispatch”。例如,在 GPU 上使用自动类型转换运行神经网络时,首次调用将是处理任何潜在自动类型转换逻辑并向下 redispatch 的自动类型转换内核。接下来的功能将是自动微分,它会正确创建自动微分图并向下 redispatch。最后,到达 CUDA 的后端内核,它将启动正确的 CUDA 内核并返回最终结果。在返回过程中,自动微分会将计算图附加到输出上,最后自动类型转换有机会在退出时进行所需的任何更新。
dispatcher 的一个配置是调用所有这些功能和后端键的顺序。最新的列表及其顺序可以在 DispatchKey.h 文件中的 DispatchKey 枚举中找到。为了扩展 torch 的目的,本次讨论中重要的顺序子集是:
vmap -> Autocast -> Autograd -> ZeroTensor -> Neg/Conj -> Functionalize -> Python -> Backends
本次讨论的关键在于 Python,因为每个定义了 __torch_dispatch__ 方法的 Tensor 子类都会调用这一功能。正是从这里开始调用用户定义的方法,并可以任意覆盖其行为。随后再次调用提供的 func 将执行"redispatch"。
这个实现有一些重要的含义:
这段代码运行在“所有特性之下”。因此,它只像常规后端一样负责生成每个 Tensor 的输出值(并且可以也应该忽略所有高级特性,如 autograd、autocast 等)。
如果任何高级特性在不进行重新分发的情况下实现了某个函数,它将永远不会到达 Python 关键,因此
__torch_dispatch__回调也永远不会被触发。这特别发生在 CompositeImplicitAutograd 函数中,它们在 Autograd 级别被评估而不进行重新分发。这是因为 CompositeImplicitAutograd 函数通过隐式调用其他原生操作来指定其 autograd 公式,所以在 Autograd 级别,该函数被分解为其原生算子,然后对这些算子进行评估。当回调到 Python 并包装结果时,与常规的 PyTorch Python/C++绑定使用相同的转换。特别是,某些对象无法在 Python 中表示,需要特殊处理(例如未定义的张量会变成
None)。原生函数被懒惰地填充为
torch.ops.{namespace}.{func_name}.{overload_name}可调用的 Python 对象,以便从 Python 中轻松地与之交互。传递给__torch_dispatch__的 func 对象始终是此命名空间中的一个条目。此命名空间可用于直接调用原生算子,并绕过通常的 Python API 和绑定代码。
与 __torch_function__ 能够介入 Torch 所有 Python API 和 Tensor 方法类似,__torch_dispatch__ 可以拦截所有对 aten 原生 API 的调用。需要注意的是,所有 Tensor 上的方法在进入分发器之前都会被转换为函数调用,因此在这里都会以函数调用的形式出现:torch.add(a, 2) 和 a + 2 最终会产生完全相同的 aten 调用。这些函数大多定义在 native_functions.yaml 中,该文件指定了这些函数的属性及其后端实现。然后,它们的实现以及指定的特性会通过代码生成自动注册。一些更特殊的函数或特性也会在 C++ 代码库的其他地方或用户定义的 C++ 扩展中进行注册。
也可以使用 torch.library() 添加新的原生函数。这一 Python 功能允许为原生函数定义和/或添加新的实现。这可用于添加缺失的内核、替换现有的内核或定义全新的原生函数。
你可以在 subclass zoo 仓库中找到许多基于 __torch_dispatch__ 的子类示例。
__torch_dispatch__ 调用约定#
@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
...
当用户使用带有 __torch_dispatch__ 的输入调用算子时,该调用可能会被转发到 __torch_dispatch__ 。在调用 __torch_dispatch__ 之前,args 和 kwargs 会被标准化,即:
kwargs 由算子模式中的仅限关键字参数组成。如果某个 kwarg 等于其默认值(在模式中),则不会传递该参数。
args 包含所有其他参数,无论它们如何传递给算子(位置参数与关键字参数)。如果某个参数等于其默认值,并且它是最右边的位置参数,或者它右边的所有参数都没有传递,那么它将不会被传递。