from typing import List, Union
import copy
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision
from torchvision.models import vgg16, VGG16_Weights
from torchvision import transforms, datasets
from tqdm import tqdm
Structured Pruning with CNNs
In the previous part of this blog series, we covered the basics of pruning, specifically fine-grained pruning. If you haven’t had a chance to read it yet, I encourage you to check out that notebook to familiarize yourself with the foundational concepts and to understand the motivation behind what we will do now.
In this section, we will review the different types of pruning methods, focusing particularly on channel-wise pruning. Pruning ranges from fine-grained approaches, where individual connections are zeroed out based on their importance, to more structured forms like channel-wise pruning. Channel-wise pruning is a more regular method: instead of scattering zeros throughout a weight tensor, it removes entire channels from the network’s layers.
As with fine-grained pruning, we will use weight magnitudes as a heuristic measure of importance to guide our pruning decisions. The key advantage of channel-wise pruning over fine-grained approaches, however, is its structure. Fine-grained pruning leaves every weight tensor the same shape, only sparser, so turning its zeros into real savings requires sparse kernels or specialized hardware. Channel-wise pruning instead slices entire chunks out of the parameter tensors, producing smaller dense tensors that standard hardware can execute efficiently.
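To make the structural difference concrete, here is a toy NumPy sketch. It uses a random tensor shaped like a conv weight and a 50% ratio, both arbitrary choices for illustration rather than the VGG weights used later. Both variants keep 144 of the 288 weights, but only the channel-pruned tensor actually shrinks:

```python
import numpy as np

# Illustrative conv-style weight tensor: (c_out, c_in, k_h, k_w)
rng = np.random.default_rng(0)
weight = rng.standard_normal((8, 4, 3, 3))

# Fine-grained (magnitude) pruning: zero out the 50% smallest-magnitude entries.
# The tensor keeps its shape; it only becomes sparser.
threshold = np.quantile(np.abs(weight), 0.5)
fine_grained = np.where(np.abs(weight) > threshold, weight, 0.0)

# Channel-wise pruning: keep the 4 output channels (filters) with the largest
# L2 norms and slice the rest away. The tensor itself becomes smaller and stays dense.
filter_norms = np.sqrt((weight ** 2).sum(axis=(1, 2, 3)))
keep_idx = np.sort(np.argsort(filter_norms)[::-1][:4])
channel_pruned = weight[keep_idx]

print(fine_grained.shape, np.count_nonzero(fine_grained))
print(channel_pruned.shape, np.count_nonzero(channel_pruned))
```

The fine-grained result still has shape (8, 4, 3, 3) and needs sparse-aware kernels to benefit from its zeros. The channel-pruned result is an ordinary dense (4, 4, 3, 3) array.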
This article was inspired by the labs of MIT’s EfficientML.ai course. In the following sections, we will delve into the specifics of channel-wise pruning and how it can be implemented to compress a model while keeping it computationally efficient. This will be similar to the methodology followed in Han et al. (2015).
Setup
We will be using the same setup as the previous blog: a VGG-16 architecture that we will prune and evaluate on the CIFAR-10 dataset.
weights = VGG16_Weights.DEFAULT
model = vgg16(weights=weights)

batch_size = 64
root = "./data"
train_ds = datasets.CIFAR10(root=root,
                            train=True,
                            download=True,
                            transform=weights.transforms())
test_ds = datasets.CIFAR10(root=root,
                           train=False,
                           download=True,
                           transform=weights.transforms())
train_dl = DataLoader(train_ds,
                      batch_size=batch_size,
                      shuffle=True)
test_dl = DataLoader(test_ds,
                     batch_size=batch_size,
                     shuffle=False)
def get_num_parameters(model: nn.Module, count_nonzero_only=False) -> int:
    """
    Calculate the total number of parameters of a model
    :param count_nonzero_only: only count nonzero weights
    """
    num_counted_elements = 0
    for param in model.parameters():
        if count_nonzero_only:
            # count_nonzero() returns a tensor; .item() keeps the result a plain int
            num_counted_elements += param.count_nonzero().item()
        else:
            num_counted_elements += param.numel()
    return num_counted_elements

def get_model_size(model: nn.Module, data_width=32, count_nonzero_only=False) -> int:
    """
    Calculate the model size in bits
    :param data_width: #bits per element
    :param count_nonzero_only: only count nonzero weights
    """
    return get_num_parameters(model, count_nonzero_only) * data_width
def train_step(model, dataloader, criterion, optimizer, device):
    model.train()
    train_loss = 0.
    train_acc = 0.

    for step, (X, y) in tqdm(enumerate(dataloader), desc="Training", leave=False):
        X, y = X.to(device), y.to(device)

        logits = model(X)
        loss = criterion(logits, y)

        train_loss += loss.item()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        y_pred = torch.argmax(logits.detach(), dim=1)
        train_acc += ((y_pred == y).sum().item() / len(y))

    train_loss = train_loss / len(dataloader)
    train_acc = train_acc / len(dataloader)
    return train_loss, train_acc
@torch.inference_mode()
def eval_step(model, dataloader, criterion, device):
    model.eval()
    eval_loss = 0.
    eval_acc = 0.

    for (X, y) in tqdm(dataloader, desc="Evaluating", leave=False):
        X, y = X.to(device), y.to(device)

        logits = model(X)
        loss = criterion(logits, y)

        eval_loss += loss.item()

        y_pred = torch.argmax(logits.detach(), dim=1)
        eval_acc += ((y_pred == y).sum().item() / len(y))

    eval_loss = eval_loss / len(dataloader)
    eval_acc = eval_acc / len(dataloader)
    return eval_loss, eval_acc
criterion = nn.CrossEntropyLoss()
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

ckpt_path = "vgg16.pth"
# map_location lets a checkpoint saved on GPU load on a CPU-only machine too
model.load_state_dict(torch.load(ckpt_path, map_location=device))

# Get original model size and benchmark accuracy
val_loss, orig_acc = eval_step(model, test_dl, criterion, device)
MB = 8 * 1024**2
orig_model_size_mb = get_model_size(model) / MB
print(f"Original model accuracy: {orig_acc:.2f}")
print(f"Original model size: {orig_model_size_mb:.2f} MB")
Original model accuracy: 0.90
Original model size: 527.79 MB
Channel-wise Pruning
To begin with channel-wise pruning, it’s essential to understand the structure of the weight tensor in a convolutional block. A convolutional layer typically has a weight tensor with dimensions corresponding to the number of input channels, output channels, and the kernel size.
When considering channel-wise pruning, we need to determine which axis of this tensor we will be pruning along. Specifically, our goal is to identify and remove the less important channels from either the input or output of the convolutional layer. We will see later what considerations we have to keep in mind.
# Examine the structure of a Convolution layer
in_chans = 32
out_chans = 64
kernel_size = 3
conv_layer = nn.Conv2d(in_channels=in_chans,
                       out_channels=out_chans,
                       kernel_size=kernel_size)

conv_weight = conv_layer.weight.data

print(f"In Channels: {in_chans}")
print(f"Out Channels: {out_chans}")
print(f"Kernel Size: {kernel_size}")
print(f"Shape of Conv Layer weight tensor: {conv_weight.shape}")
In Channels: 32
Out Channels: 64
Kernel Size: 3
Shape of Conv Layer weight tensor: torch.Size([64, 32, 3, 3])
To delve deeper into channel-wise pruning, it’s crucial to examine the weight tensor’s structure more closely. For a convolutional layer, the weight tensor typically has the shape \((c_{out}, c_{in}, k_h, k_w)\), where \(c_{out}\) and \(c_{in}\) denote the number of output and input channels respectively, and \(k_h\) and \(k_w\) represent the height and width of the kernels.
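The parameter count of a convolution is the product of these four dimensions, plus one bias per output channel. Shrinking either channel axis therefore shrinks the layer proportionally. A minimal sketch follows; conv_num_params is a hypothetical helper, not part of PyTorch:

```python
def conv_num_params(c_out: int, c_in: int, k_h: int, k_w: int, bias: bool = True) -> int:
    """Parameters of a Conv2d layer whose weight has shape (c_out, c_in, k_h, k_w)."""
    num_weights = c_out * c_in * k_h * k_w
    return num_weights + (c_out if bias else 0)

# The example layer above: 32 input channels, 64 output channels, 3x3 kernels
print(conv_num_params(64, 32, 3, 3))   # 64*32*3*3 + 64 = 18496
# Halving the input channels halves the weights but leaves the 64 biases alone
print(conv_num_params(64, 16, 3, 3))   # 64*16*3*3 + 64 = 9280
```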
When considering channel-wise pruning, our primary focus is on the second axis of the weight tensor. This is the \(c_{in}\) dimension, which indexes the input channels, and pruning channels means slicing out portions along this axis. However, the input channels of one layer are exactly the output channels of the layer before it. Removing input channel \(j\) from a layer only makes sense if the preceding layer stops producing activation channel \(j\), which means slicing that layer’s weight tensor along its first (\(c_{out}\)) axis as well. In other words, pruning changes the shape of the activation maps flowing between the two layers.

To implement channel-wise pruning correctly, we therefore need to understand how a change to one layer’s weight tensor affects its neighbors. Any layer that consumes the pruned activations must be adjusted so that it can still process them; otherwise the forward pass fails with a shape mismatch.

Examining the VGG architecture lets us develop a pruning strategy that accounts for these dependencies. VGG networks consist of a series of convolutional layers followed by fully connected layers. If we prune the output channels of one convolutional layer, we must update the input channels of the next one to match.

In summary, slicing channels out of a weight tensor is straightforward, but the resulting change in activation shapes requires careful consideration of the network’s overall architecture. By adjusting neighboring layers consistently, we can apply channel-wise pruning while keeping the model functional.
# Examine model architecture
model
VGG(
(features): Sequential(
(0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): ReLU(inplace=True)
(2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(3): ReLU(inplace=True)
(4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(6): ReLU(inplace=True)
(7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(8): ReLU(inplace=True)
(9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(11): ReLU(inplace=True)
(12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(13): ReLU(inplace=True)
(14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(15): ReLU(inplace=True)
(16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(18): ReLU(inplace=True)
(19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(20): ReLU(inplace=True)
(21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(22): ReLU(inplace=True)
(23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(25): ReLU(inplace=True)
(26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(27): ReLU(inplace=True)
(28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(29): ReLU(inplace=True)
(30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)
(avgpool): AdaptiveAvgPool2d(output_size=(7, 7))
(classifier): Sequential(
(0): Linear(in_features=25088, out_features=4096, bias=True)
(1): ReLU(inplace=True)
(2): Dropout(p=0.5, inplace=False)
(3): Linear(in_features=4096, out_features=4096, bias=True)
(4): ReLU(inplace=True)
(5): Dropout(p=0.5, inplace=False)
(6): Linear(in_features=4096, out_features=1000, bias=True)
)
)
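The VGG architecture features alternating blocks of convolutional layers, ReLUs, and max-pooling operations. ReLU and max-pooling act on each channel independently and have no parameters, so they work with any number of channels. To apply channel-wise pruning, we therefore only need to focus on each pair of adjacent convolutional layers.

For each pair, we prune the output channels of the preceding convolution and the matching input channels of the subsequent convolution. This keeps the changes consistent throughout the network and maintains the correct flow of data. The input of the first convolution (the three RGB channels) and the output of the last one (which feeds the 25088-dimensional classifier input, 512 × 7 × 7) are left untouched. A network with \(n\) convolutional layers therefore has \(n - 1\) prunable pairs.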
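The bookkeeping for VGG-16 can be sketched in plain Python. This illustration only computes the resulting (c_in, c_out) pairs from the layer sizes in the printout above. channels_to_keep and pruned_conv_shapes are hypothetical helpers; channels_to_keep copies the rounding rule of get_num_channels_to_keep, defined later in this post:

```python
# Output channels of VGG-16's 13 convolutional layers (from the printout above)
VGG16_CONV_OUT = [64, 64, 128, 128, 256, 256, 256, 512, 512, 512, 512, 512, 512]

def channels_to_keep(channels: int, prune_ratio: float) -> int:
    # Same rounding rule as get_num_channels_to_keep in this post
    return int(round((1 - prune_ratio) * channels))

def pruned_conv_shapes(conv_out, prune_ratio, in_chans=3):
    """Return (c_in, c_out) for every conv layer after uniform pruning.

    The output of conv i and the input of conv i+1 are pruned together;
    the very first input and the very last output are left unchanged.
    """
    shapes = []
    c_in = in_chans
    for i, c_out in enumerate(conv_out):
        is_last = i == len(conv_out) - 1
        new_out = c_out if is_last else channels_to_keep(c_out, prune_ratio)
        shapes.append((c_in, new_out))
        c_in = new_out  # the next layer must accept exactly this many channels
    return shapes

for c_in, c_out in pruned_conv_shapes(VGG16_CONV_OUT, 0.4):
    print(c_in, c_out)
```

With a uniform ratio of 0.4, the first layer becomes (3, 38) and the last becomes (307, 512). Every layer’s c_in equals the previous layer’s c_out, which is the consistency the paired pruning steps preserve.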
# Dummy example
class ConvBlock(nn.Module):
    def __init__(self, in_chans, hidden_size, out_chans):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=in_chans, out_channels=hidden_size, kernel_size=3)
        self.mp = nn.MaxPool2d(2)
        self.conv2 = nn.Conv2d(in_channels=hidden_size, out_channels=out_chans, kernel_size=3)

    def forward(self, x):
        return self.conv2(self.mp(self.conv1(x)))

in_chans = 16
hidden_size = 32
out_chans = 64
img_size = 28

x = torch.randn(in_chans, img_size, img_size)
block = ConvBlock(in_chans, hidden_size, out_chans)

out = block(x)
block_orig_numparams = get_num_parameters(block)
print(f"Number of parameters in block: {block_orig_numparams}")
print(f"Output shape: {out.shape}")
Number of parameters in block: 23136
Output shape: torch.Size([64, 11, 11])
def get_num_channels_to_keep(channels: int, prune_ratio: float) -> int:
    """
    Calculate the number of channels to PRESERVE after pruning
    """
    return int(round((1 - prune_ratio) * channels))
# Get the layers from the block
conv1 = block.conv1
mp = block.mp
conv2 = block.conv2

# Start pruning the channels
prune_ratio = 0.6
original_channels = conv1.out_channels
n_keep = get_num_channels_to_keep(original_channels, prune_ratio)

# 1. Prune the output channels of the first convolution layer
with torch.no_grad():
    conv1.weight = nn.Parameter(
        conv1.weight.detach()[:n_keep, ...]
    )
    # Adjust the bias as well, if it exists
    if conv1.bias is not None:
        conv1.bias = nn.Parameter(
            conv1.bias.detach()[:n_keep]
        )

# 2. Prune the affected input channels of the next convolution
with torch.no_grad():
    conv2.weight = nn.Parameter(
        conv2.weight.detach()[:, :n_keep, ...]
    )
    # Bias does not need to be adjusted for this layer

# Re-run the forward pass so the printed shape reflects the pruned block
out = block(x)
print(f"Number of parameters in block (post-pruning): {get_num_parameters(block)}")
print(f"Output shape: {out.shape}")
Number of parameters in block (post-pruning): 9437
Output shape: torch.Size([64, 11, 11])
After experimenting with this dummy block, we observe that the number of parameters drops from 23136 to 9437 while the output shape stays the same. The remaining fraction (about 0.41) is close to, but not exactly, the kept ratio 1 - prune_ratio = 0.4, for two reasons. First, the number of kept channels is rounded (13 of 32). Second, only the channels between the two convolutions are pruned, so the block’s input channels, its output channels, and the second convolution’s bias are untouched.
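These counts can be checked by hand. A small sketch, assuming 3×3 kernels with biases as in ConvBlock (conv_params is a hypothetical helper):

```python
def conv_params(c_in, c_out, k=3):
    # weights (c_out, c_in, k, k) plus one bias per output channel
    return c_out * c_in * k * k + c_out

in_chans, hidden, out_chans = 16, 32, 64
n_keep = int(round((1 - 0.6) * hidden))   # round(12.8) = 13

before = conv_params(in_chans, hidden) + conv_params(hidden, out_chans)
after = conv_params(in_chans, n_keep) + conv_params(n_keep, out_chans)

print(before, after, round(after / before, 3))
```

The weights shrink by exactly 13/32 ≈ 0.406. The 64 untouched biases of the second convolution nudge the overall ratio slightly above that.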
To streamline this process, we’ll define a function that automates channel-wise pruning for the entire model. This function will handle adjusting the number of input and output channels for each convolutional layer, ensuring that the pruned network remains consistent and functional. This automated approach will simplify the pruning procedure and facilitate efficient model compression.
# Reference: https://hanlab.mit.edu/courses/2024-fall-65940
@torch.no_grad()
def channel_prune(model: nn.Module,
                  prune_ratio: Union[List, float]) -> nn.Module:
    """
    Apply channel pruning to each convolutional layer in the backbone.
    :param model: The model to be pruned
    :param prune_ratio: Either a single float for uniform pruning across
        layers or a list of floats specifying per-layer pruning rates.
    :return pruned_model: The model with pruned channels
    """
    assert isinstance(prune_ratio, (float, list))

    conv_layers = [m for m in model.features if isinstance(m, nn.Conv2d)]
    n_conv = len(conv_layers)

    # Each ratio affects one conv's output channels and the next conv's input channels
    if isinstance(prune_ratio, float):
        prune_ratio = [prune_ratio] * (n_conv - 1)
    else:
        assert len(prune_ratio) == n_conv - 1, "prune_ratio list length must be one less than the number of Conv2d layers."

    # Create a deepcopy so we don't modify the original
    pruned_model = copy.deepcopy(model)
    conv_layers = [m for m in pruned_model.features if isinstance(m, nn.Conv2d)]

    # Apply channel pruning to each pair of consecutive convolutional layers
    for i, ratio in enumerate(prune_ratio):
        prev_conv = conv_layers[i]
        next_conv = conv_layers[i + 1]
        prev_channels = prev_conv.out_channels
        n_keep = get_num_channels_to_keep(prev_channels, ratio)

        with torch.no_grad():
            # Prune the output channels of the previous convolution
            prev_conv.weight = nn.Parameter(prev_conv.weight.detach()[:n_keep, ...])
            if prev_conv.bias is not None:
                prev_conv.bias = nn.Parameter(prev_conv.bias.detach()[:n_keep])

            # Prune the input channels of the next convolution
            next_conv.weight = nn.Parameter(next_conv.weight.detach()[:, :n_keep, ...])

    print("Channel pruning completed. Note: The printed model structure may not reflect the pruned dimensions.")
    return pruned_model
# Prune the model without considering the channel importances
pruned_model_naive = channel_prune(model, 0.4)
pruned_model_size_mb = get_model_size(pruned_model_naive) / MB

print(f"Original model size: {orig_model_size_mb:.2f} MB")
print(f"Pruned model size: {pruned_model_size_mb:.2f} MB")
Channel pruning completed. Note: The printed model structure may not reflect the pruned dimensions.
Original model size: 527.79 MB
Pruned model size: 494.03 MB
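The modest size reduction follows from where VGG-16 keeps its parameters. The back-of-the-envelope sketch below recomputes both sizes from the layer shapes in the printout above. It assumes 3×3 convolutions with biases, the 1000-way classifier as printed, 32-bit weights, and the same rounding rule as get_num_channels_to_keep. The helper names are hypothetical:

```python
# Output channels of VGG-16's 13 conv layers (3x3 kernels, with bias)
CONV_OUT = [64, 64, 128, 128, 256, 256, 256, 512, 512, 512, 512, 512, 512]
# Classifier layers as (in_features, out_features), matching the printout
LINEAR = [(25088, 4096), (4096, 4096), (4096, 1000)]
MB = 8 * 1024**2  # bits per MB, as in the post

def keep(c, ratio):
    return int(round((1 - ratio) * c))

def conv_shapes(ratio):
    # First input (RGB) and last output stay fixed; everything in between is pruned
    shapes, c_in = [], 3
    for i, c_out in enumerate(CONV_OUT):
        new_out = c_out if i == len(CONV_OUT) - 1 else keep(c_out, ratio)
        shapes.append((c_in, new_out))
        c_in = new_out
    return shapes

def num_params(ratio):
    conv = sum(co * ci * 9 + co for ci, co in conv_shapes(ratio))
    linear = sum(i * o + o for i, o in LINEAR)  # untouched by channel_prune
    return conv, linear

for ratio in (0.0, 0.4):
    conv, linear = num_params(ratio)
    size_mb = (conv + linear) * 32 / MB
    print(f"ratio={ratio}: conv={conv:,} linear={linear:,} size={size_mb:.2f} MB")
```

This reproduces 527.79 MB and 494.03 MB and shows why the saving is small. The three fully connected layers hold 123,642,856 of the 138,357,544 parameters (about 89%), and channel_prune leaves them untouched. Meanwhile, the convolutional parameters shrink from 14,714,688 to 5,863,188.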
# Evaluate the model after this crude pruning
_, acc = eval_step(pruned_model_naive, test_dl, criterion, device)
print(f"Accuracy after naive pruning: {acc:.2f}")
Accuracy after naive pruning: 0.10
Pruning less-important channels
When pruning channels, it’s crucial not to remove them arbitrarily. As we saw above, discarding “important” channels can wreck performance: accuracy fell to 0.10, which is chance level for the 10 CIFAR-10 classes. To avoid this, we need to prioritize channels based on their importance.
A practical approach to determining channel importance is to use the norms of the weight tensors as a measure. The idea is that channels with larger norms are generally more critical for the network’s performance. By calculating these norms, we can rank the channels and selectively prune the less important ones.
The process involves defining the importance of each channel, sorting the channels accordingly, and then retaining only the most important ones. We can integrate this approach into our existing pruning function, which previously kept the first n_keep channels. By incorporating channel importance into this function, we can make pruning more strategic, preserving the channels that contribute most to the network’s performance.
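Here is a minimal NumPy sketch of rank-and-keep on a toy pair of layers, with random weights and sizes chosen only for illustration. It scores the next layer’s input channels with the Frobenius norm, keeps the n_keep highest-scoring ones, and slices the same indices out of both layers:

```python
import numpy as np

rng = np.random.default_rng(0)
# Toy pair of conv layers: prev produces 6 channels, next consumes them
prev_w = rng.standard_normal((6, 3, 3, 3))   # (c_out, c_in, k, k)
prev_b = rng.standard_normal(6)
next_w = rng.standard_normal((8, 6, 3, 3))   # (c_out, c_in=6, k, k)

# Importance of each of next's input channels: Frobenius norm of next_w[:, j]
importance = np.sqrt((next_w ** 2).sum(axis=(0, 2, 3)))

n_keep = 4
# Indices of the n_keep most important channels, most important first
keep_idx = np.argsort(importance)[::-1][:n_keep]

# Slice the same channels out of both layers so their shapes stay consistent
prev_w_pruned = prev_w[keep_idx]        # axis 0: prev's output channels
prev_b_pruned = prev_b[keep_idx]
next_w_pruned = next_w[:, keep_idx]     # axis 1: next's input channels

print(prev_w_pruned.shape, next_w_pruned.shape)
```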
# Grab a random convolution weight tensor to demonstrate computing channel importances
rand_weight = model.features[2].weight
print(f"Shape: {rand_weight.shape}")

def get_input_channel_importance(weight):
    in_channels = weight.shape[1]
    importances = []

    # Compute the importance for each input channel
    for i_c in range(in_channels):
        channel_weight = weight.detach()[:, i_c]  # (c_out, k, k)
        importance = torch.norm(channel_weight)  # take the Frobenius norm
        importances.append(importance.view(1))
    return torch.cat(importances)

print(f"Importances of the 64 input channels:\n{get_input_channel_importance(rand_weight)}")
Shape: torch.Size([64, 64, 3, 3])
Importances of the 64 input channels:
tensor([1.6980, 1.7297, 2.1539, 0.9993, 1.8365, 0.8037, 1.6834, 0.9603, 1.1872,
1.3582, 1.1134, 1.0917, 1.4087, 0.9050, 1.7213, 0.7823, 1.1833, 1.1185,
1.1565, 2.2874, 1.3824, 1.4115, 1.0174, 1.3120, 1.1977, 0.9108, 0.7976,
0.9291, 1.1520, 1.1238, 0.9578, 0.7938, 1.4062, 1.4817, 2.5130, 1.0180,
1.3782, 0.9571, 0.9826, 1.3465, 1.0445, 0.8921, 1.5498, 0.7251, 1.1079,
0.9550, 1.3082, 1.2728, 0.9647, 0.8078, 0.7796, 2.1746, 1.1919, 2.0185,
0.7407, 1.4707, 1.0315, 1.8911, 2.1096, 2.2035, 0.9893, 0.7218, 0.7914,
1.1358], device='cuda:0')
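The Python loop in get_input_channel_importance can also be written as a single reduction over the c_out, k_h, and k_w axes. The NumPy sketch below checks that both forms agree on a random tensor with the same shape as model.features[2].weight. In PyTorch, torch.linalg.vector_norm(weight, dim=(0, 2, 3)) should compute the same quantity:

```python
import numpy as np

def input_channel_importance_loop(weight):
    # Mirrors get_input_channel_importance: one Frobenius norm per input channel
    return np.array([np.linalg.norm(weight[:, i_c]) for i_c in range(weight.shape[1])])

def input_channel_importance_vec(weight):
    # Same quantity in one shot: reduce over c_out, k_h, k_w (axes 0, 2, 3)
    return np.sqrt((weight ** 2).sum(axis=(0, 2, 3)))

rng = np.random.default_rng(0)
w = rng.standard_normal((64, 64, 3, 3))  # same shape as model.features[2].weight
print(np.allclose(input_channel_importance_loop(w), input_channel_importance_vec(w)))
```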
# Reference: https://hanlab.mit.edu/courses/2024-fall-65940
@torch.no_grad()
def apply_channel_sorting(model):
    '''
    Sorts the channels in decreasing order of importance for the given model
    :param model: Model to apply the channel sorting to
    '''
    # Create a deep copy of the model to avoid modifying the original
    sorted_model = copy.deepcopy(model)

    # Fetch all the convolutional layers from the backbone
    conv_layers = [m for m in sorted_model.features if isinstance(m, nn.Conv2d)]

    # Iterate through the convolutional layers and sort channels by importance
    for i in range(len(conv_layers) - 1):
        prev_conv = conv_layers[i]
        next_conv = conv_layers[i + 1]

        # Compute the importance of input channels for the next convolutional layer
        importance = get_input_channel_importance(next_conv.weight)
        sort_idx = torch.argsort(importance, descending=True)

        # Sort the output channels of the previous convolutional layer
        prev_conv.weight = nn.Parameter(torch.index_select(prev_conv.weight.detach(), 0, sort_idx))
        if prev_conv.bias is not None:
            prev_conv.bias = nn.Parameter(torch.index_select(prev_conv.bias.detach(), 0, sort_idx))

        # Sort the input channels of the next convolutional layer
        next_conv.weight = nn.Parameter(torch.index_select(next_conv.weight.detach(), 1, sort_idx))

    return sorted_model
We directly manipulate the slices of the corresponding channels within the weight and bias tensors. By sorting channels based on their importance in decreasing order, we can identify and retain the most critical channels as per our desired pruning ratio.
After sorting, we keep the first n_keep slices, which ensures that the most important channels are preserved. apply_channel_sorting applies the same permutation to the output channels of one convolution and the input channels of the next. The sorting step alone therefore does not change what the network computes; it only moves the most important channels to the front, where our existing channel_prune function picks them up unchanged.
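Why is sorting safe? The NumPy sketch below simplifies each convolution to a 1×1 convolution, i.e. a matrix acting on the channel axis; this assumption only keeps the example short. Permuting the previous layer’s output rows and bias together with the next layer’s input columns leaves the output unchanged, because ReLU acts on each channel independently. The same argument carries over to k×k kernels and max-pooling:

```python
import numpy as np

rng = np.random.default_rng(0)
# Two 1x1 "convolutions" reduce to matrix products over channels, with a ReLU in between
W1, b1 = rng.standard_normal((6, 3)), rng.standard_normal(6)   # prev: 3 -> 6 channels
W2, b2 = rng.standard_normal((4, 6)), rng.standard_normal(4)   # next: 6 -> 4 channels
x = rng.standard_normal(3)

def forward(W1, b1, W2, b2, x):
    return W2 @ np.maximum(W1 @ x + b1, 0.0) + b2

# Sort channels by the importance of next's input channels (column norms of W2)
sort_idx = np.argsort(np.linalg.norm(W2, axis=0))[::-1]

# Apply the SAME permutation to prev's outputs (rows, bias) and next's inputs (columns)
W1_s, b1_s, W2_s = W1[sort_idx], b1[sort_idx], W2[:, sort_idx]

print(np.allclose(forward(W1, b1, W2, b2, x), forward(W1_s, b1_s, W2_s, b2, x)))
```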
channel_pruning_ratio = 0.4  # pruned-out ratio

print("Without sorting channels by importance...")
pruned_model = channel_prune(model, channel_pruning_ratio)
_, acc = eval_step(pruned_model, test_dl, criterion, device)
print(f"Pruned model accuracy: {acc:.2f}")

print('-'*25)

print("With sorting channels by importance...")
sorted_model = apply_channel_sorting(model)
pruned_model = channel_prune(sorted_model, channel_pruning_ratio)
_, acc = eval_step(pruned_model, test_dl, criterion, device)
print(f"Pruned model accuracy: {acc:.2f}")
Without sorting channels by importance...
Channel pruning completed. Note: The printed model structure may not reflect the pruned dimensions.
Pruned model accuracy: 0.10
-------------------------
With sorting channels by importance...
Channel pruning completed. Note: The printed model structure may not reflect the pruned dimensions.
Pruned model accuracy: 0.15
Recovering performance with Finetuning
With sorting, the accuracy drop is slightly smaller than with the naive approach (0.15 versus 0.10), but it is still severe compared with the original 0.90. A large drop is typical right after channel-wise pruning. Entire chunks of the model are removed at once, and the remaining weights were never trained to compensate for them.
To mitigate the performance loss, we again employ fine-tuning. This step is essential to allow the model to adjust to the new, pruned structure and recover some of its original performance. Fine-tuning helps recalibrate the remaining parameters and optimize the network for its reduced size.
# Finetune to recover performance
learning_rate = 1e-4
epochs = 3
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(pruned_model.parameters(), lr=learning_rate)

for epoch in tqdm(range(epochs), desc="Epochs"):
    train_loss, train_acc = train_step(pruned_model, train_dl, criterion, optimizer, device)
    val_loss, val_acc = eval_step(pruned_model, test_dl, criterion, device)
    print(f"Epoch {epoch+1}... Train Accuracy: {train_acc:.2f} | Validation Accuracy: {val_acc:.2f}")
Epochs: 100%|██████████| 3/3 [06:47<00:00, 135.89s/it]
Epoch 1... Train Accuracy: 0.87 | Validation Accuracy: 0.85
Epoch 2... Train Accuracy: 0.94 | Validation Accuracy: 0.90
Epoch 3... Train Accuracy: 0.97 | Validation Accuracy: 0.90
_, acc = eval_step(pruned_model, test_dl, criterion, device)
print(f"Validation accuracy of original (dense) model: {orig_acc:.3f}")
print(f"Final validation accuracy of pruned model: {acc:.3f}")

pruned_model_size_mb = get_model_size(pruned_model) / MB
print(f"Original model size: {orig_model_size_mb:.2f} MB")
print(f"Pruned model size: {pruned_model_size_mb:.2f} MB")
Validation accuracy of original (dense) model: 0.898
Final validation accuracy of pruned model: 0.899
Original model size: 527.79 MB
Pruned model size: 494.03 MB
# Test the latency of the original and pruned models
x = next(iter(train_dl))[0].to(device)

%timeit model(x)
%timeit pruned_model(x)
53.7 ms ± 85.7 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
35.2 ms ± 75.3 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
The outputs show that the pruned model (virtually) matches the accuracy of the original model on the dataset (0.899 versus 0.898). It is also smaller (494.03 MB versus 527.79 MB) and faster at inference (35.2 ms versus 53.7 ms per batch of 64 in the timing above).
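The speedup comes from the convolutions, which dominate compute even though the classifier dominates storage. Below is a rough multiply-accumulate (MAC) count in plain Python. It assumes 224×224 inputs, which is what VGG16_Weights.DEFAULT.transforms() produces by resizing and center-cropping, and it ignores ReLU, pooling, and the classifier. conv_shapes and conv_macs are hypothetical helpers:

```python
# Spatial size seen by each of VGG-16's 13 convs for a 224x224 input
# (3x3 convs with padding 1 keep the size; each max-pool halves it)
SPATIAL = [224, 224, 112, 112, 56, 56, 56, 28, 28, 28, 14, 14, 14]
CONV_OUT = [64, 64, 128, 128, 256, 256, 256, 512, 512, 512, 512, 512, 512]

def conv_shapes(ratio):
    # Same pairing as channel_prune: first input and last output stay fixed
    shapes, c_in = [], 3
    for i, c_out in enumerate(CONV_OUT):
        new_out = c_out if i == len(CONV_OUT) - 1 else int(round((1 - ratio) * c_out))
        shapes.append((c_in, new_out))
        c_in = new_out
    return shapes

def conv_macs(ratio):
    # Multiply-accumulates per image: H * W * c_in * c_out * 3 * 3
    return sum(s * s * ci * co * 9 for s, (ci, co) in zip(SPATIAL, conv_shapes(ratio)))

dense, pruned = conv_macs(0.0), conv_macs(0.4)
print(f"dense: {dense / 1e9:.2f} GMACs, pruned: {pruned / 1e9:.2f} GMACs, ratio: {pruned / dense:.2f}")
```

This estimate puts the pruned backbone at roughly 37% of the original convolutional MACs. The measured latency ratio (35.2 / 53.7 ≈ 0.66) is less dramatic. That is expected, since real runtime also depends on memory traffic, the unpruned layers, and how well the GPU is utilized.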
One of the notable advantages of channel-wise pruning is that it does not require special hardware to handle the zeros introduced during fine-grained pruning. Since channel-wise pruning involves removing entire channels rather than zeroing out individual connections, the pruned model becomes inherently more efficient. This structured approach results in a more compact model that requires less storage and computation, making it easier to deploy and operate without the need for hardware designed to optimize sparse matrices. Thus, channel-wise pruning not only helps reduce the model size but also simplifies the computational requirements for inference, leading to overall efficiency gains.