FlopCounterMode
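Reference: the-ideal-pytorch-flop-counter-with-torch-dispatch
FlopCounterMode can 1. count floating-point operations (FLOPs) at the operator level, 2. (optionally) aggregate those counts at the module level, 3. capture the FLOPs of the backward pass, and 4. work in eager mode. Oh, and you can also use it under arbitrary transformations (such as vmap) to count the FLOPs of computing a Jacobian or Hessian!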
from torch.utils.flop_counter import FlopCounterMode
FlopCounterMode??
import torch
import torchvision.models as models
with torch.device('meta'):
    inp = torch.randn((8, 3, 224, 224))
    mod = models.resnet18()
with FlopCounterMode() as flop_counter:
    mod(inp)
Module                   FLOP    % Total
--------------------  -------  ---------
ResNet                29.025B    100.00%
 - aten.convolution   29.017B     99.97%
 - aten.addmm          0.008B      0.03%
ResNet.conv1           1.888B      6.51%
 - aten.convolution    1.888B      6.51%
ResNet.fc              0.008B      0.03%
 - aten.addmm          0.008B      0.03%
ResNet.layer1          7.399B     25.49%
 - aten.convolution    7.399B     25.49%
ResNet.layer2          6.577B     22.66%
 - aten.convolution    6.577B     22.66%
ResNet.layer3          6.577B     22.66%
 - aten.convolution    6.577B     22.66%
ResNet.layer4          6.577B     22.66%
 - aten.convolution    6.577B     22.66%
import torch
import torchvision.models as models
with torch.device('cpu'):
    inp = torch.randn((8, 3, 224, 224))
    mod = models.resnet18()
with FlopCounterMode() as flop_counter:
    mod(inp)
Module                   FLOP    % Total
--------------------  -------  ---------
ResNet                29.025B    100.00%
 - aten.convolution   29.017B     99.97%
 - aten.addmm          0.008B      0.03%
ResNet.conv1           1.888B      6.51%
 - aten.convolution    1.888B      6.51%
ResNet.fc              0.008B      0.03%
 - aten.addmm          0.008B      0.03%
ResNet.layer1          7.399B     25.49%
 - aten.convolution    7.399B     25.49%
ResNet.layer2          6.577B     22.66%
 - aten.convolution    6.577B     22.66%
ResNet.layer3          6.577B     22.66%
 - aten.convolution    6.577B     22.66%
ResNet.layer4          6.577B     22.66%
 - aten.convolution    6.577B     22.66%
from torch.utils.flop_counter import FlopCounterMode
import torch
import torchvision.models as models
with torch.device('meta'):
    inp = torch.randn((1, 3, 224, 224))
    mod = models.resnet18()
with FlopCounterMode() as flop_counter:
    mod(inp).sum()
Module                     FLOP    % Total
--------------------  ---------  ---------
ResNet                3628.147M    100.00%
 - aten.convolution   3627.123M     99.97%
 - aten.addmm            1.024M      0.03%
ResNet.conv1           236.028M      6.51%
 - aten.convolution    236.028M      6.51%
ResNet.fc                1.024M      0.03%
 - aten.addmm            1.024M      0.03%
ResNet.layer1          924.844M     25.49%
 - aten.convolution    924.844M     25.49%
ResNet.layer2          822.084M     22.66%
 - aten.convolution    822.084M     22.66%
ResNet.layer3          822.084M     22.66%
 - aten.convolution    822.084M     22.66%
ResNet.layer4          822.084M     22.66%
 - aten.convolution    822.084M     22.66%
print(f"{flop_counter.get_total_flops()/1e9:.2f} GFLOPs")
3.63 GFLOPs