FlopCounterMode

FlopCounterMode

参考:PyTorch 开发者论坛(dev-discuss.pytorch.org)文章《The Ideal PyTorch FLOP Counter (with `__torch_dispatch__`)》

该计数器能够:1. 在算子级别统计浮点运算次数(FLOPs);2. (可选地)按模块层级聚合这些计数;3. 捕获反向传播中的浮点运算次数;4. 在即时执行(eager)模式下工作。此外,还可以借助任意函数变换(如 `vmap`)来统计计算雅可比矩阵或海森矩阵所需的浮点运算次数!

from torch.utils.flop_counter import FlopCounterMode
# IPython/Jupyter only: `obj??` displays the init signature and full source of
# `obj` (shown in the cell output below). This line is not valid plain Python.
FlopCounterMode??

Hide code cell output

Init signature:
FlopCounterMode(
    mods: Union[torch.nn.modules.module.Module, list[torch.nn.modules.module.Module], NoneType] = None,
    depth: int = 2,
    display: bool = True,
    custom_mapping: Optional[dict[Any, Any]] = None,
)
Source:        
class FlopCounterMode:
    """
    ``FlopCounterMode`` is a context manager that counts the number of flops within its context.

    It does this using a ``TorchDispatchMode``.

    It also supports hierarchical output by passing a module (or list of
    modules) to FlopCounterMode on construction. If you do not need hierarchical
    output, you do not need to use it with a module.

    Example usage

    .. code-block:: python

        mod = ...
        with FlopCounterMode(mod) as flop_counter:
            mod.sum().backward()

    """

    def __init__(
            self,
            mods: Optional[Union[torch.nn.Module, list[torch.nn.Module]]] = None,
            depth: int = 2,
            display: bool = True,
            custom_mapping: Optional[dict[Any, Any]] = None):
        super().__init__()
        self.flop_counts: dict[str, dict[Any, int]] = defaultdict(lambda: defaultdict(int))
        self.depth = depth
        self.display = display
        self.mode: Optional[_FlopCounterMode] = None
        if custom_mapping is None:
            custom_mapping = {}
        if mods is not None:
            warnings.warn("mods argument is not needed anymore, you can stop passing it", stacklevel=2)
        self.flop_registry = {
            **flop_registry,
            **{k: v if getattr(v, "_get_raw", False) else shape_wrapper(v) for k, v in custom_mapping.items()}
        }
        self.mod_tracker = ModuleTracker()

    def get_total_flops(self) -> int:
        return sum(self.flop_counts['Global'].values())

    def get_flop_counts(self) -> dict[str, dict[Any, int]]:
        """Return the flop counts as a dictionary of dictionaries.

        The outer
        dictionary is keyed by module name, and the inner dictionary is keyed by
        operation name.

        Returns:
            Dict[str, Dict[Any, int]]: The flop counts as a dictionary.
        """
        return {k: dict(v) for k, v in self.flop_counts.items()}

    def get_table(self, depth=None):
        if depth is None:
            depth = self.depth
        if depth is None:
            depth = 999999

        import tabulate
        tabulate.PRESERVE_WHITESPACE = True
        header = ["Module", "FLOP", "% Total"]
        values = []
        global_flops = self.get_total_flops()
        global_suffix = get_suffix_str(global_flops)
        is_global_subsumed = False

        def process_mod(mod_name, depth):
            nonlocal is_global_subsumed

            total_flops = sum(self.flop_counts[mod_name].values())

            is_global_subsumed |= total_flops >= global_flops

            padding = " " * depth
            values = []
            values.append([
                padding + mod_name,
                convert_num_with_suffix(total_flops, global_suffix),
                convert_to_percent_str(total_flops, global_flops)
            ])
            for k, v in self.flop_counts[mod_name].items():
                values.append([
                    padding + " - " + str(k),
                    convert_num_with_suffix(v, global_suffix),
                    convert_to_percent_str(v, global_flops)
                ])
            return values

        for mod in sorted(self.flop_counts.keys()):
            if mod == 'Global':
                continue
            mod_depth = mod.count(".") + 1
            if mod_depth > depth:
                continue

            cur_values = process_mod(mod, mod_depth - 1)
            values.extend(cur_values)

        # We do a bit of messing around here to only output the "Global" value
        # if there are any FLOPs in there that aren't already fully contained by
        # a module.
        if 'Global' in self.flop_counts and not is_global_subsumed:
            for value in values:
                value[0] = " " + value[0]

            values = process_mod('Global', 0) + values

        if len(values) == 0:
            values = [["Global", "0", "0%"]]

        return tabulate.tabulate(values, headers=header, colalign=("left", "right", "right"))

    # NB: This context manager is NOT reentrant
    def __enter__(self):
        self.flop_counts.clear()
        self.mod_tracker.__enter__()
        self.mode = _FlopCounterMode(self)
        self.mode.__enter__()
        return self

    def __exit__(self, *args):
        assert self.mode is not None
        b = self.mode.__exit__(*args)
        self.mode = None  # break cycles
        self.mod_tracker.__exit__()
        if self.display:
            print(self.get_table(self.depth))
        return b

    def _count_flops(self, func_packet, out, args, kwargs):
        if func_packet in self.flop_registry:
            flop_count_func = self.flop_registry[func_packet]
            flop_count = flop_count_func(*args, **kwargs, out_val=out)  # type: ignore[operator]
            for par in set(self.mod_tracker.parents):
                self.flop_counts[par][func_packet] += flop_count

        return out
File:           /media/pc/data/lxw/envs/anaconda3a/envs/ai/lib/python3.12/site-packages/torch/utils/flop_counter.py
Type:           type
Subclasses:     
import torch
from torchvision import models


# Build the 8-image input batch (3x224x224) and the ResNet-18 on the "meta"
# device. NOTE: meta tensors are assumed to skip real allocation/compute;
# only the operator shapes matter for FLOP counting.
with torch.device("meta"):
    inp = torch.randn(8, 3, 224, 224)
    mod = models.resnet18()

# Count the FLOPs of one forward pass; the table is printed when the block exits.
with FlopCounterMode() as flop_counter:
    mod(inp)
Module                   FLOP    % Total
--------------------  -------  ---------
ResNet                29.025B    100.00%
 - aten.convolution   29.017B     99.97%
 - aten.addmm          0.008B      0.03%
 ResNet.conv1          1.888B      6.51%
  - aten.convolution   1.888B      6.51%
 ResNet.fc             0.008B      0.03%
  - aten.addmm         0.008B      0.03%
 ResNet.layer1         7.399B     25.49%
  - aten.convolution   7.399B     25.49%
 ResNet.layer2         6.577B     22.66%
  - aten.convolution   6.577B     22.66%
 ResNet.layer3         6.577B     22.66%
  - aten.convolution   6.577B     22.66%
 ResNet.layer4         6.577B     22.66%
  - aten.convolution   6.577B     22.66%
import torch
import torchvision.models as models
from torch.utils.flop_counter import FlopCounterMode


# Same measurement as on the meta device, but with real CPU tensors.
with torch.device('cpu'):
    inp = torch.randn((8, 3, 224, 224))
    mod = models.resnet18()
# This forward pass does real computation. No backward pass follows, so
# autograd recording is disabled; this should only avoid saving activations
# and not change which forward ops are counted.
with FlopCounterMode() as flop_counter, torch.no_grad():
    mod(inp)
Module                   FLOP    % Total
--------------------  -------  ---------
ResNet                29.025B    100.00%
 - aten.convolution   29.017B     99.97%
 - aten.addmm          0.008B      0.03%
 ResNet.conv1          1.888B      6.51%
  - aten.convolution   1.888B      6.51%
 ResNet.fc             0.008B      0.03%
  - aten.addmm         0.008B      0.03%
 ResNet.layer1         7.399B     25.49%
  - aten.convolution   7.399B     25.49%
 ResNet.layer2         6.577B     22.66%
  - aten.convolution   6.577B     22.66%
 ResNet.layer3         6.577B     22.66%
  - aten.convolution   6.577B     22.66%
 ResNet.layer4         6.577B     22.66%
  - aten.convolution   6.577B     22.66%
import torch
import torchvision.models as models

# Re-run the meta-device measurement: an 8-image batch of 3x224x224 inputs
# and a ResNet-18, both created on the "meta" device.
with torch.device("meta"):
    inp = torch.randn(8, 3, 224, 224)
    mod = models.resnet18()

# display=True (the default) prints the per-module table on exit.
with FlopCounterMode(display=True) as flop_counter:
    mod(inp)
Module                   FLOP    % Total
--------------------  -------  ---------
ResNet                29.025B    100.00%
 - aten.convolution   29.017B     99.97%
 - aten.addmm          0.008B      0.03%
 ResNet.conv1          1.888B      6.51%
  - aten.convolution   1.888B      6.51%
 ResNet.fc             0.008B      0.03%
  - aten.addmm         0.008B      0.03%
 ResNet.layer1         7.399B     25.49%
  - aten.convolution   7.399B     25.49%
 ResNet.layer2         6.577B     22.66%
  - aten.convolution   6.577B     22.66%
 ResNet.layer3         6.577B     22.66%
  - aten.convolution   6.577B     22.66%
 ResNet.layer4         6.577B     22.66%
  - aten.convolution   6.577B     22.66%
from torch.utils.flop_counter import FlopCounterMode
import torch
import torchvision.models as models


# Single-image batch on the "meta" device.
with torch.device("meta"):
    inp = torch.randn(1, 3, 224, 224)
    mod = models.resnet18()

# The trailing .sum() is dispatched as well, but it does not appear in the
# table below, presumably because it has no entry in the flop registry.
with FlopCounterMode() as flop_counter:
    mod(inp).sum()
Module                     FLOP    % Total
--------------------  ---------  ---------
ResNet                3628.147M    100.00%
 - aten.convolution   3627.123M     99.97%
 - aten.addmm            1.024M      0.03%
 ResNet.conv1          236.028M      6.51%
  - aten.convolution   236.028M      6.51%
 ResNet.fc               1.024M      0.03%
  - aten.addmm           1.024M      0.03%
 ResNet.layer1         924.844M     25.49%
  - aten.convolution   924.844M     25.49%
 ResNet.layer2         822.084M     22.66%
  - aten.convolution   822.084M     22.66%
 ResNet.layer3         822.084M     22.66%
  - aten.convolution   822.084M     22.66%
 ResNet.layer4         822.084M     22.66%
  - aten.convolution   822.084M     22.66%
# Total FLOPs of the last counted run ('Global' bucket), in units of 1e9.
print("{:.2f} GFLOPs".format(flop_counter.get_total_flops() / 1e9))
3.63 GFLOPs