Model Pruning#

Author: Michela Paganini

State-of-the-art deep learning techniques rely on over-parametrized models that are hard to deploy. In contrast, biological neural networks are known to use efficient sparse connectivity. Identifying the best techniques for compressing models by reducing the number of parameters in them is important in order to reduce memory, battery, and hardware consumption without sacrificing accuracy, to deploy lightweight models on device, and to guarantee privacy with private on-device computation. On the research front, pruning is used to investigate the differences in learning dynamics between over-parametrized and under-parametrized networks, to study the role of lucky sparse subnetworks and initializations ("lottery tickets") as a destructive neural architecture search technique, and more.

Goal

Learn how to use torch.nn.utils.prune to sparsify your neural networks, and how to extend it to implement your own custom pruning technique.

import torch
from torch import nn
from torch.nn.utils import prune
import torch.nn.functional as F

Create a model#

We use the LeNet architecture ([LeCun et al., 1998]) as the example.

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class LeNet(nn.Module):
    def __init__(self):
        super().__init__()
        # 1 input image channel, 6 output channels, 3x3 square conv kernel
        self.conv1 = nn.Conv2d(1, 6, 3)
        self.conv2 = nn.Conv2d(6, 16, 3)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)  # 5x5 image dimension
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = x.view(-1, int(x.nelement() / x.shape[0]))
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

model = LeNet().to(device=device)

Inspect a Module#

Let's inspect the (unpruned) conv1 layer in our LeNet model. It contains two parameters, weight and bias, and no buffers for now.

module = model.conv1
print(list(module.named_parameters()))
[('weight', Parameter containing:
tensor([[[[ 0.1876,  0.0330,  0.2159],
          [-0.0102, -0.1868, -0.0845],
          [-0.3198,  0.0996, -0.0450]]],


        [[[ 0.2758, -0.0678,  0.1061],
          [ 0.0937, -0.0211, -0.3151],
          [-0.2742, -0.2375, -0.2396]]],


        [[[-0.2714,  0.1406,  0.1719],
          [-0.0776, -0.0367,  0.3224],
          [-0.2366, -0.2433, -0.1543]]],


        [[[ 0.0750,  0.0104, -0.2353],
          [ 0.3158,  0.2640,  0.0445],
          [-0.0940,  0.0760,  0.0505]]],


        [[[ 0.2400, -0.0794,  0.0088],
          [-0.2652, -0.0784,  0.2804],
          [-0.1599, -0.1666, -0.0620]]],


        [[[ 0.0119, -0.1247, -0.0289],
          [-0.1928, -0.1183,  0.0925],
          [ 0.1004,  0.0195, -0.0648]]]], device='cuda:0', requires_grad=True)), ('bias', Parameter containing:
tensor([ 0.1226, -0.0847, -0.1435, -0.1236, -0.2621, -0.2857], device='cuda:0',
       requires_grad=True))]
print(list(module.named_buffers()))
[]

Pruning a Module#

To prune a module (in this example, the conv1 layer of our LeNet architecture), first select a pruning technique among those available in torch.nn.utils.prune (or implement your own by subclassing BasePruningMethod). Then, specify the module and the name of the parameter to prune within that module. Finally, using the appropriate keyword arguments required by the selected pruning technique, specify the pruning parameters.

In this example, we will prune at random \(30\%\) of the connections in the parameter named weight in the conv1 layer. The module is passed as the first argument to the function; name identifies the parameter within that module using its string identifier; and amount indicates either the percentage of connections to prune (if it is a float between 0. and 1.), or the absolute number of connections to prune (if it is a non-negative integer).

prune.random_unstructured(module, name="weight", amount=0.3)
Conv2d(1, 6, kernel_size=(3, 3), stride=(1, 1))

Pruning acts by removing weight from the parameters and replacing it with a new parameter called weight_orig (i.e. appending "_orig" to the initial parameter name). weight_orig stores the unpruned version of the tensor. The bias was not pruned, so it will remain intact.

print(list(module.named_parameters()))
[('bias', Parameter containing:
tensor([ 0.1226, -0.0847, -0.1435, -0.1236, -0.2621, -0.2857], device='cuda:0',
       requires_grad=True)), ('weight_orig', Parameter containing:
tensor([[[[ 0.1876,  0.0330,  0.2159],
          [-0.0102, -0.1868, -0.0845],
          [-0.3198,  0.0996, -0.0450]]],


        [[[ 0.2758, -0.0678,  0.1061],
          [ 0.0937, -0.0211, -0.3151],
          [-0.2742, -0.2375, -0.2396]]],


        [[[-0.2714,  0.1406,  0.1719],
          [-0.0776, -0.0367,  0.3224],
          [-0.2366, -0.2433, -0.1543]]],


        [[[ 0.0750,  0.0104, -0.2353],
          [ 0.3158,  0.2640,  0.0445],
          [-0.0940,  0.0760,  0.0505]]],


        [[[ 0.2400, -0.0794,  0.0088],
          [-0.2652, -0.0784,  0.2804],
          [-0.1599, -0.1666, -0.0620]]],


        [[[ 0.0119, -0.1247, -0.0289],
          [-0.1928, -0.1183,  0.0925],
          [ 0.1004,  0.0195, -0.0648]]]], device='cuda:0', requires_grad=True))]

The pruning mask generated by the pruning technique selected above is saved as a module buffer named weight_mask (i.e. appending "_mask" to the initial parameter name).

print(list(module.named_buffers()))

For the forward pass to work without modification, the weight attribute needs to exist. The pruning techniques implemented in torch.nn.utils.prune compute the pruned version of the weight (by combining the mask with the original parameter) and store them in the attribute weight. Note, this is no longer a parameter of the module, it is now simply an attribute.

print(module.weight)
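
The relationship between weight, weight_orig, and weight_mask can be checked directly. Below is a minimal self-contained sketch on a fresh linear layer (the layer shape and pruning amount are illustrative, not part of the LeNet example above):

```python
import torch
from torch import nn
from torch.nn.utils import prune

# A small standalone layer, pruned the same way as conv1 above
layer = nn.Linear(4, 2)
prune.random_unstructured(layer, name="weight", amount=0.5)

# The pruned weight is exactly the original parameter times the mask
assert torch.equal(layer.weight, layer.weight_orig * layer.weight_mask)

# The mask zeroes out round(0.5 * 8) = 4 of the 8 entries
print(int((layer.weight_mask == 0).sum()))  # 4
```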

Finally, pruning is applied prior to each forward pass using PyTorch’s forward_pre_hooks. Specifically, when the module is pruned, as we have done here, it will acquire a forward_pre_hook for each parameter associated with it that gets pruned. In this case, since we have so far only pruned the original parameter named weight, only one hook will be present.

print(module._forward_pre_hooks)

For completeness, we can now prune the bias too, to see how the parameters, buffers, hooks, and attributes of the module change. Just for the sake of trying out another pruning technique, here we prune the 3 smallest entries in the bias by L1 norm, as implemented in the l1_unstructured pruning function.

prune.l1_unstructured(module, name="bias", amount=3)

We now expect the named parameters to include both weight_orig (from before) and bias_orig. The buffers will include weight_mask and bias_mask. The pruned versions of the two tensors will exist as module attributes, and the module will now have two forward_pre_hooks.

print(list(module.named_parameters()))
print(list(module.named_buffers()))
print(module.bias)
print(module._forward_pre_hooks)

Iterative Pruning#

The same parameter in a module can be pruned multiple times, with the effect of the various pruning calls being equal to the combination of the various masks applied in series. The combination of a new mask with the old mask is handled by the PruningContainer’s compute_mask method.

Say, for example, that we now want to further prune module.weight, this time using structured pruning along the 0th axis of the tensor (the 0th axis corresponds to the output channels of the convolutional layer and has dimensionality 6 for conv1), based on the channels’ L2 norm. This can be achieved using the ln_structured function, with n=2 and dim=0.

prune.ln_structured(module, name="weight", amount=0.5, n=2, dim=0)

# As we can verify, this will zero out all the connections corresponding to 
# 50% (3 out of 6) of the channels, while preserving the action of the 
# previous mask.
print(module.weight)

The corresponding hook will now be of type torch.nn.utils.prune.PruningContainer, and will store the history of pruning applied to the weight parameter.

for hook in module._forward_pre_hooks.values():
    if hook._tensor_name == "weight":  # select out the correct hook
        break

print(list(hook))  # pruning history in the container
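
A standalone sketch (using an illustrative linear layer rather than conv1) shows the serial behavior: entries zeroed by an earlier mask are never resurrected by a later one, and the hook becomes a PruningContainer after the second pruning call.

```python
import torch
from torch import nn
from torch.nn.utils import prune

layer = nn.Linear(8, 8)
prune.l1_unstructured(layer, name="weight", amount=0.5)
mask_after_first = layer.weight_mask.clone()

# Second round acts only on the surviving half of the entries
prune.ln_structured(layer, name="weight", amount=0.5, n=2, dim=0)

# The combined mask can only remove additional entries, never
# restore ones zeroed by the first mask
assert torch.all(layer.weight_mask[mask_after_first == 0] == 0)

# The hook for "weight" is now a PruningContainer holding both methods
for hook in layer._forward_pre_hooks.values():
    if hook._tensor_name == "weight":
        break
assert isinstance(hook, prune.PruningContainer)
```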

Serializing a pruned model#

All relevant tensors, including the mask buffers and the original parameters used to compute the pruned tensors are stored in the model’s state_dict and can therefore be easily serialized and saved, if needed.

print(model.state_dict().keys())
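
As a sketch of a full round trip (using an in-memory buffer and a small illustrative layer rather than LeNet), the pruned state can be saved and restored, as long as the receiving module is pruned first so that its state_dict keys match:

```python
import io
import torch
from torch import nn
from torch.nn.utils import prune

model_a = nn.Linear(4, 2)
prune.random_unstructured(model_a, name="weight", amount=0.5)

# Serialize to an in-memory buffer instead of a file on disk
buffer = io.BytesIO()
torch.save(model_a.state_dict(), buffer)
buffer.seek(0)
state = torch.load(buffer)

# The mask and the original parameter travel with the state_dict
assert "weight_orig" in state and "weight_mask" in state

# A fresh module must be pruned first so its keys match on load
model_b = nn.Linear(4, 2)
prune.random_unstructured(model_b, name="weight", amount=0.5)
model_b.load_state_dict(state)

# A forward pass re-applies the restored mask via the pre-hook
x = torch.randn(1, 4)
assert torch.equal(model_a(x), model_b(x))
```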

Remove pruning re-parametrization#

To make the pruning permanent, remove the re-parametrization in terms of weight_orig and weight_mask, and remove the forward_pre_hook, we can use the remove functionality from torch.nn.utils.prune. Note that this doesn’t undo the pruning, as if it never happened. It simply makes it permanent, instead, by reassigning the parameter weight to the model parameters, in its pruned version.

Prior to removing the re-parametrization:

print(list(module.named_parameters()))
print(list(module.named_buffers()))
print(module.weight)

After removing the re-parametrization:

prune.remove(module, 'weight')
print(list(module.named_parameters()))
print(list(module.named_buffers()))
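
A self-contained sketch (on an illustrative linear layer) confirms what remove does and does not do: the re-parametrization disappears, but the zeros stay.

```python
import torch
from torch import nn
from torch.nn.utils import prune

layer = nn.Linear(4, 2)
prune.l1_unstructured(layer, name="weight", amount=0.5)
pruned_weight = layer.weight.detach().clone()

prune.remove(layer, "weight")

# "weight" is a plain parameter again; the re-parametrization is gone
assert "weight" in dict(layer.named_parameters())
assert "weight_orig" not in dict(layer.named_parameters())
assert len(list(layer.named_buffers())) == 0
assert len(layer._forward_pre_hooks) == 0

# The zeros introduced by pruning are kept permanently
assert torch.equal(layer.weight, pruned_weight)
```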

Pruning multiple parameters in a model#

By specifying the desired pruning technique and parameters, we can easily prune multiple tensors in a network, perhaps according to their type, as we will see in this example.

new_model = LeNet()
for name, module in new_model.named_modules():
    # prune 20% of connections in all 2D-conv layers 
    if isinstance(module, torch.nn.Conv2d):
        prune.l1_unstructured(module, name='weight', amount=0.2)
    # prune 40% of connections in all linear layers 
    elif isinstance(module, torch.nn.Linear):
        prune.l1_unstructured(module, name='weight', amount=0.4)

print(dict(new_model.named_buffers()).keys())  # to verify that all masks exist

Global pruning#

So far, we only looked at what is usually referred to as “local” pruning, i.e. the practice of pruning tensors in a model one by one, by comparing the statistics (weight magnitude, activation, gradient, etc.) of each entry exclusively to the other entries in that tensor. However, a common and perhaps more powerful technique is to prune the model all at once, by removing (for example) the lowest 20% of connections across the whole model, instead of removing the lowest 20% of connections in each layer. This is likely to result in different pruning percentages per layer. Let’s see how to do that using global_unstructured from torch.nn.utils.prune.

model = LeNet()

parameters_to_prune = (
    (model.conv1, 'weight'),
    (model.conv2, 'weight'),
    (model.fc1, 'weight'),
    (model.fc2, 'weight'),
    (model.fc3, 'weight'),
)

prune.global_unstructured(
    parameters_to_prune,
    pruning_method=prune.L1Unstructured,
    amount=0.2,
)

Now we can check the sparsity induced in every pruned parameter, which will not be equal to 20% in each layer. However, the global sparsity will be (approximately) 20%.

print(
    "Sparsity in conv1.weight: {:.2f}%".format(
        100. * float(torch.sum(model.conv1.weight == 0))
        / float(model.conv1.weight.nelement())
    )
)
print(
    "Sparsity in conv2.weight: {:.2f}%".format(
        100. * float(torch.sum(model.conv2.weight == 0))
        / float(model.conv2.weight.nelement())
    )
)
print(
    "Sparsity in fc1.weight: {:.2f}%".format(
        100. * float(torch.sum(model.fc1.weight == 0))
        / float(model.fc1.weight.nelement())
    )
)
print(
    "Sparsity in fc2.weight: {:.2f}%".format(
        100. * float(torch.sum(model.fc2.weight == 0))
        / float(model.fc2.weight.nelement())
    )
)
print(
    "Sparsity in fc3.weight: {:.2f}%".format(
        100. * float(torch.sum(model.fc3.weight == 0))
        / float(model.fc3.weight.nelement())
    )
)
print(
    "Global sparsity: {:.2f}%".format(
        100. * float(
            torch.sum(model.conv1.weight == 0)
            + torch.sum(model.conv2.weight == 0)
            + torch.sum(model.fc1.weight == 0)
            + torch.sum(model.fc2.weight == 0)
            + torch.sum(model.fc3.weight == 0)
        )
        / float(
            model.conv1.weight.nelement()
            + model.conv2.weight.nelement()
            + model.fc1.weight.nelement()
            + model.fc2.weight.nelement()
            + model.fc3.weight.nelement()
        )
    )
)
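
The repetitive printout above can be condensed into a small helper; `global_sparsity` below is a hypothetical name, and the two-layer model is illustrative rather than LeNet:

```python
import torch
from torch import nn
from torch.nn.utils import prune

def global_sparsity(pairs):
    """Fraction of zero entries across all given (module, name) pairs."""
    zeros = sum(float(torch.sum(getattr(m, n) == 0)) for m, n in pairs)
    total = sum(float(getattr(m, n).nelement()) for m, n in pairs)
    return zeros / total

# Two layers of very different sizes, pruned globally
layers = [nn.Linear(100, 100), nn.Linear(10, 10)]
pairs = [(layer, "weight") for layer in layers]
prune.global_unstructured(pairs, pruning_method=prune.L1Unstructured, amount=0.2)

# Per-layer sparsity varies, but the global figure is ~20%
for i, (m, n) in enumerate(pairs):
    w = getattr(m, n)
    print(f"layer {i}: {100. * float(torch.sum(w == 0)) / w.nelement():.2f}%")
print(f"global: {100. * global_sparsity(pairs):.2f}%")
```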

Extending torch.nn.utils.prune with custom pruning functions#

To implement your own pruning function, you can extend the nn.utils.prune module by subclassing the BasePruningMethod base class, the same way all other pruning methods do. The base class implements the following methods for you: __call__, apply_mask, apply, prune, and remove. Beyond some special cases, you shouldn’t have to reimplement these methods for your new pruning technique. You will, however, have to implement __init__ (the constructor), and compute_mask (the instructions on how to compute the mask for the given tensor according to the logic of your pruning technique). In addition, you will have to specify which type of pruning this technique implements (supported options are global, structured, and unstructured). This is needed to determine how to combine masks in the case in which pruning is applied iteratively. In other words, when pruning a pre-pruned parameter, the current pruning technique is expected to act on the unpruned portion of the parameter. Specifying the PRUNING_TYPE will enable the PruningContainer (which handles the iterative application of pruning masks) to correctly identify the slice of the parameter to prune.

Let’s assume, for example, that you want to implement a pruning technique that prunes every other entry in a tensor (or – if the tensor has previously been pruned – in the remaining unpruned portion of the tensor). This will be of PRUNING_TYPE='unstructured' because it acts on individual connections in a layer and not on entire units/channels ('structured'), or across different parameters ('global').

class FooBarPruningMethod(prune.BasePruningMethod):
    """Prune every other entry in a tensor
    """
    PRUNING_TYPE = 'unstructured'

    def compute_mask(self, t, default_mask):
        mask = default_mask.clone()
        mask.view(-1)[::2] = 0 
        return mask

Now, to apply this to a parameter in an nn.Module, you should also provide a simple function that instantiates the method and applies it.

def foobar_unstructured(module, name):
    """Prunes tensor corresponding to parameter called `name` in `module`
    by removing every other entry in the tensors.
    Modifies module in place (and also returns the modified module) 
    by:
    1) adding a named buffer called `name+'_mask'` corresponding to the 
    binary mask applied to the parameter `name` by the pruning method.
    The parameter `name` is replaced by its pruned version, while the 
    original (unpruned) parameter is stored in a new parameter named 
    `name+'_orig'`.

    Args:
        module (nn.Module): module containing the tensor to prune
        name (string): parameter name within `module` on which pruning
                will act.

    Returns:
        module (nn.Module): modified (i.e. pruned) version of the input
            module
    
    Examples:
        >>> m = nn.Linear(3, 4)
        >>> foobar_unstructured(m, name='bias')
    """
    FooBarPruningMethod.apply(module, name)
    return module

Let’s try it out!

model = LeNet()
foobar_unstructured(model.fc3, name='bias')

print(model.fc3.bias_mask)