子类化 torch.Tensor#
从 1.7.0 版本开始,torch.Tensor 上的方法以及应用于 torch.Tensor 子类的公共 torch.* 命名空间函数将返回子类实例,而非 torch.Tensor 实例:
import torch
class SubTensor(torch.Tensor):
...
type(torch.add(SubTensor([0]), SubTensor([1]))).__name__, type(torch.add(SubTensor([0]), torch.tensor([1]))).__name__
('SubTensor', 'SubTensor')
如果存在多个子类,默认会选择层次结构中最底层的那个。如果无法以唯一方式确定这种情况,则会引发 TypeError 错误:
>>> type(torch.add(SubTensor2([0]), SubTensor([1]))).__name__
'SubTensor2'
>>> type(torch.add(SubTensor2([0]), torch.tensor([1]))).__name__
'SubTensor2'
>>> torch.add(SubTensor([0]), OtherSubTensor([1]))
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
TypeError: no implementation found for 'torch.add' on types that implement __torch_function__: [SubTensor, OtherSubTensor]
若希望对所有张量方法进行全局覆盖,可以使用 __torch_function__ 。以下是记录所有函数/方法调用的示例:
class LoggingTensor(torch.Tensor):
@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
# NOTE: Logging calls Tensor.__repr__, so we can't log __repr__ without infinite recursion
if func is not torch.Tensor.__repr__:
logging.info(f"func: {func.__name__}, args: {args!r}, kwargs: {kwargs!r}")
if kwargs is None:
kwargs = {}
return super().__torch_function__(func, types, args, kwargs)
然而,如果希望覆盖 Tensor 子类上的方法,可以通过直接覆盖该方法(通过为子类定义它),或者使用 __torch_function__ 并与 func 匹配来实现。
在 __torch_function__ 中,子类应当始终调用 super().__torch_function__(func, ...) 而不是直接调用 func ,就像在 1.7.0 版本之前的做法一样。如果未能这样做,可能会导致 func 递归回 __torch_function__ ,从而引发无限递归。
扩展 torch 的 Tensor 包装器类型#
另一个有用的案例是封装张量的类型,无论是作为属性还是通过子类化。下面实现了这种类型的特例,即 MetadataTensor,它将元数据字典附加到张量上,并通过 torch 算子传播。由于这是对完整 torch API 的通用封装,不需要单独实现每个重写,因此可以使 __torch_function__ 的实现对允许的算子更加宽松:
class MetadataTensor(object):
def __init__(self, data, metadata=None, **kwargs):
self._t = torch.as_tensor(data, **kwargs)
self._metadata = metadata
def __repr__(self):
return "Metadata:\n{}\n\ndata:\n{}".format(self._metadata, self._t)
@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
if kwargs is None:
kwargs = {}
metadatas = tuple(a._metadata for a in args if hasattr(a, '_metadata'))
args = [getattr(a, '_t', a) for a in args]
assert len(metadatas) > 0
ret = func(*args, **kwargs)
return MetadataTensor(ret, metadata=metadatas[0])
这个简单的实现不一定会适用于 torch API 中的每一个函数,但它足以涵盖大多数常见算子:
metadata = {'owner': 'Ministry of Silly Walks'}
m = MetadataTensor([[1, 2], [3, 4]], metadata=metadata)
t = torch.tensor([[1, 2], [1, 2]])
torch.add(t, m), torch.mul(t, m)
(Metadata:
{'owner': 'Ministry of Silly Walks'}
data:
tensor([[2, 4],
[4, 6]]),
Metadata:
{'owner': 'Ministry of Silly Walks'}
data:
tensor([[1, 4],
[3, 8]]))