Torch Dispatch Mode#
不幸的是,有些函数不接受张量输入。这意味着子类方法不能用来覆盖所有 PyTorch 函数的行为。此外,如果用例需要拦截每个函数调用,将每个张量都改为子类可能会过于侵入性。
Torch Dispatch Mode 为用户提供了一种在 __torch_dispatch__ 级别拦截所有调用的方法,包括工厂函数。您可能见过其他关于使用张量子类实现类似目的的笔记(或整个仓库)。模式类似于子类,但它们还允许捕获工厂函数,或者不接受张量参数的函数,例如 torch.randn 或 torch.ones。
from torch.utils._python_dispatch import TorchDispatchMode
为了解决这一用例,引入了“模式”的概念。这些模式适用于__torch_function__和__torch_dispatch__的重写,分别通过继承torch.overrides.TorchFunctionMode和torch.utils._python_dispatch.TorchDispatchMode来创建,并作为上下文管理器使用。
为了简化描述其与子类和其他模式的交互方式,每当进入某个模式的上下文管理器时,每个函数的行为都像是在参数列表开头多了一个以该模式作为子类的Tensor参数。这意味着特别是所有模式处理程序都会在任何子类处理程序之前被调用,并且对应于内部上下文管理器的模式总是最先运行。
还需要注意的是,在特定的模式处理程序中,此特定模式会被禁用,并且可以通过执行 with self: 手动重新启用。
基础#
以下是一段基本的 PrintingMode 代码,它会打印出它看到的每次调用
class PrintingMode(TorchDispatchMode):
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
print(f"{func.__module__}.{func.__name__}({args}, {kwargs})")
return func(*args, **kwargs)
状态化#
模式相比子类的好处是,它们可以携带状态。所以如果你想使用相同的 PrintingMode,但让所有内容写入 logger 对象而不是仅仅打印出来,你可以这样做
class LoggingMode(TorchDispatchMode):
def __init__(self, logger):
self.logger = logger
return super().__init__()
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
self.logger.log(f"{func.__module__}.{func.__name__}", args, kwargs)
return func(*args, **kwargs)
注意这里,相同的模式总是写入同一个 Logger 对象。你可以同时运行多个使用不同 logger 对象的 LoggerMode
同样需要注意的是,这需要实现好的 Logger 对象。如果你需要已经工作的版本,PyTorch 已经有完全工作的 [LoggingTensorMode]。
使用模式#
然后,要使用这些模式,你可以将你想记录的日志的调用包裹起来:
with LoggingMode(logger):
<call to be logged>
你可能已经看到过别人使用类似的东西
with enable_torch_dispatch_mode(LoggingMode(logger)):
在大多数情况下,它们会执行相同的操作。不过,如果你要在已经使用某种模式运行的代码中添加模式,你需要使用 with LoggingMode() (如果你使用启用版本,会看到错误提示)。
模式帮助调试和统计:
它还用于以下测试中:
什么时候应该使用张量子类?什么时候应该使用模式?什么时候应该同时使用两者?#
先说明一下这里有一些微妙之处,在模式出现之前,通常是通过使用子类来解决类似的问题。个人认为模式编写起来更快,且出错的可能性更少,所以建议从模式开始使用,如果有必要的话再添加子类。
在这一声明之后,基本的原则是:如果你只是想查看所有击中 __torch_dispatch__ 的函数,你应该使用模式。如果你关心跟踪张量参数在不同调用之间的传递,或者传播特定的张量状态(如分发键),你应该使用子类。再详细分解一下:
模式#
在这里,考虑刚刚看到的调试模式。对于这些模式,只关心被调用的函数及其传入的参数。不需要跟踪被调用的张量。此外,打算使用传入的张量来运行所有函数。因此,模式在这里是足够的。
子类#
需要子类化的例子包括 ProxyTensors 和 FakeTensors。
对于 ProxyTensors,AOTAutograd 的基础,需要知道当同一个参数被重复使用时,以便提供正确的简洁 graphs。在这里,希望图使用相同的 Proxy 对象,因此需要子类。
仿张量在跟踪过程中避免实际计算的成本。为了不运行计算,需要使用元内核(meta kernel),这意味着需要在张量上使用 meta dispatch ke,而不需要在用户的张量上使用。因此,需要将子类传递下去,以正确设置整个计算过程中的 dispatch 键和设备类型。
需要注意的是,这两个子类也使用一种模式来捕获工厂函数并将其包装在子类中。这通常是预期的行为,并且一般建议在子类中使用这种模式。还有一些关于如何使用的最佳实践。
示例#
这里有一个示例,展示了每种类型的日志模式:
import torch
from torch.overrides import TorchFunctionMode, resolve_name
from torch.utils._python_dispatch import TorchDispatchMode
class FunctionLog(TorchFunctionMode):
def __torch_function__(self, func, types, args, kwargs=None):
print(f"Function Log: {resolve_name(func)}(*{args}, **{kwargs})")
return func(*args, **(kwargs or {}))
class DispatchLog(TorchDispatchMode):
def __torch_dispatch__(self, func, types, args, kwargs=None):
print(f"Dispatch Log: {func}(*{args}, **{kwargs})")
return func(*args, **(kwargs or {}))
def f():
a = torch.rand(10, requires_grad=True)
b = a * 2
b.sum().backward()
print("TorchFunctionMode logging:")
with FunctionLog():
f()
print("TorchDispatchMode logging:")
with DispatchLog():
f()
TorchFunctionMode logging:
Function Log: torch.rand(*(10,), **{'requires_grad': True})
Function Log: torch.Tensor.mul(*(tensor([0.2549, 0.8917, 0.8051, 0.7601, 0.6900, 0.5605, 0.2226, 0.2205, 0.6670,
0.5728], requires_grad=True), 2), **None)
Function Log: torch.Tensor.sum(*(tensor([0.5099, 1.7835, 1.6103, 1.5202, 1.3800, 1.1211, 0.4452, 0.4410, 1.3339,
1.1455], grad_fn=<MulBackward0>),), **None)
Function Log: torch.Tensor.backward(*(tensor(11.2904, grad_fn=<SumBackward0>),), **{'gradient': None, 'retain_graph': None, 'create_graph': False, 'inputs': None})
TorchDispatchMode logging:
Dispatch Log: aten.rand.default(*([10],), **{'device': device(type='cpu'), 'pin_memory': False})
Dispatch Log: aten.mul.Tensor(*(tensor([0.6559, 0.9715, 0.8585, 0.9084, 0.1064, 0.6898, 0.7138, 0.3887, 0.8222,
0.1663], requires_grad=True), 2), **{})
Dispatch Log: aten.sum.default(*(tensor([1.3117, 1.9431, 1.7171, 1.8167, 0.2129, 1.3795, 1.4275, 0.7773, 1.6444,
0.3325], grad_fn=<MulBackward0>),), **{})
Dispatch Log: aten.ones_like.default(*(tensor(12.5628, grad_fn=<SumBackward0>),), **{'pin_memory': False, 'memory_format': torch.preserve_format})
Dispatch Log: aten.expand.default(*(tensor(1.), [10]), **{})
Dispatch Log: aten.mul.Tensor(*(tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]), 2), **{})
Dispatch Log: aten.detach.default(*(tensor([2., 2., 2., 2., 2., 2., 2., 2., 2., 2.]),), **{})
Dispatch Log: aten.detach.default(*(tensor([2., 2., 2., 2., 2., 2., 2., 2., 2., 2.]),), **{})