Meta 设备#
"meta" 设备是抽象设备,表示仅记录元数据而不包含实际数据的张量。Meta 张量有两个主要用途:
模型可以加载到元设备上,这样你就能加载模型的表示形式,而无需实际将参数加载到内存中。如果你需要在加载实际数据之前对模型进行变换,这可能会很有帮助。
大多数算子都可以在元张量(meta tensors)上执行,从而生成新的元张量,这些新张量描述了如果在真实张量上执行该算子后结果会是什么样子。你可以利用这一点来进行抽象分析,而无需花费时间进行计算或占用空间来表示实际的张量。由于元张量不包含真实数据,因此你无法执行像 torch.nonzero() 或 item() 这样的数据依赖性运算。在某些情况下,并非所有设备类型(例如 CPU 和 CUDA)对某个算子的输出元数据都完全相同;在这种情况下,通常倾向于忠实表示 CUDA 的行为。
警告
虽然原则上元张量计算应该始终比等效的 CPU/CUDA 计算更快,但许多元张量实现是用 Python 编写的,并且尚未移植到 C++ 以提高速度,因此您可能会发现使用小型 CPU 张量时,框架的绝对延迟更低。
与元张量工作的惯用法#
可以通过指定 map_location='meta',使用 torch.load() 将对象加载到元设备上:
from pathlib import Path
temp_dir = Path(".temp")
temp_dir.mkdir(exist_ok=True)
import torch
torch.save(torch.randn(2), temp_dir/'foo.pt')
torch.load(temp_dir/'foo.pt', map_location='meta')
tensor(..., device='meta', size=(2,))
如果你有一段任意代码,在没有明确指定设备的情况下执行张量构建操作,你可以通过使用 torch.device() 上下文管理器来覆盖它,改为在元设备上进行构建:
with torch.device('meta'):
print(torch.randn(30, 30))
tensor(..., device='meta', size=(30, 30))
这在 NN 模块构建中尤其有用,因为在初始化时你通常无法显式传递设备进来:
from torch.nn.modules import Linear
with torch.device('meta'):
print(Linear(20, 30))
Linear(in_features=20, out_features=30, bias=True)
你不能直接将元张量转换为 CPU/CUDA 张量,因为元张量不存储任何数据,不知道新张量的正确数据值应该是什么。
torch.ones(5, device='meta').to("cpu")
---------------------------------------------------------------------------
NotImplementedError Traceback (most recent call last)
Cell In[8], line 1
----> 1 torch.ones(5, device='meta').to("cpu")
NotImplementedError: Cannot copy out of meta tensor; no data!
使用工厂函数如 torch.empty_like() 来明确指定你希望如何填充缺失数据。
NN 模块提供了便捷方法 torch.nn.Module.to_empty(),允许您将模块移动到另一个设备上,但此时所有参数都处于未初始化状态。您需要手动显式地重新初始化这些参数:
from torch.nn.modules import Linear
with torch.device('meta'):
m = Linear(20, 30)
m.to_empty(device="cpu")
Linear(in_features=20, out_features=30, bias=True)
torch._subclasses.meta_utils 包含未公开的工具,用于获取任意张量并构建具有高保真度的等效元张量。这些 API 属于实验性质,可能会随时以不兼容的方式更改。