Distilling the Knowledge in a Neural Network

Revisiting and implementing part of the classical paper

deep learning
paper
Published January 28, 2025

Idea

This classic paper introduced distillation as a way of transferring knowledge from a big teacher network into a small student network. The core idea is to use the big model's output distribution as soft labels for training the small model.

Recall that in classification, given the predicted class probabilities \(\hat y_c\) and the target class probabilities \(y_c\) of an example, the cross-entropy loss is

\[ L(\hat y,y) = -\sum_c y_c \log \hat y_c \]

To use soft labels, we simply set the targets to the big model's predicted probabilities, \(y = f_{\text{big}}(x)\).
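As a small illustration (all probabilities below are made up), the same cross-entropy function handles both a one-hot hard label and a soft label. With the soft label every class contributes to the loss, not just the correct one.

```python
import math

def cross_entropy(y_hat, y):
    """L(y_hat, y) = -sum_c y_c * log(y_hat_c) for a target distribution y."""
    # Classes with zero target probability contribute nothing, so skip them.
    return -sum(t * math.log(p) for t, p in zip(y, y_hat) if t > 0)

# Hypothetical 3-class example: the small model's predicted probabilities.
y_hat = [0.7, 0.2, 0.1]

# Hard label: one-hot, so only the correct class enters the sum.
hard = [1.0, 0.0, 0.0]

# Soft label: a made-up teacher distribution standing in for f_big(x).
soft = [0.8, 0.15, 0.05]

print(cross_entropy(y_hat, hard))  # -log(0.7), about 0.357
print(cross_entropy(y_hat, soft))  # about 0.642
```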

These soft labels provide a much richer training signal for the small model, especially when the big model spreads its probability mass across several classes (i.e. when the labels have high entropy). To raise this entropy, the authors propose increasing the temperature \(T\) of the big model's softmax, which turns logits \(z_c\) into probabilities \(\exp(z_c/T) / \sum_{c'} \exp(z_{c'}/T)\), when producing the soft labels. The small model is trained with the same temperature, which is then reset to 1 at test time.

Increasing the temperature of the big model produces softer and more informative labels.
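The effect of the temperature can be checked directly. The sketch below applies a softmax with temperature to some made-up teacher logits for three classes and reports the entropy of the result; the logits are purely illustrative.

```python
import math

def softmax_with_temperature(logits, T=1.0):
    """Softmax of logits / T; T > 1 flattens the distribution."""
    scaled = [z / T for z in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def entropy(p):
    """Shannon entropy in nats."""
    return -sum(q * math.log(q) for q in p if q > 0)

# Made-up teacher logits for three classes.
logits = [9.0, 4.0, 3.0]

for T in (1.0, 4.0, 20.0):
    probs = softmax_with_temperature(logits, T)
    print(T, [round(p, 3) for p in probs], round(entropy(probs), 3))
```

As \(T\) grows, the probability of the top class drops, the other classes receive more mass, and the entropy increases, while the ranking of the classes stays the same.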

They also obtained better results by adding to the loss a small term with the regular cross-entropy on the hard labels. The reasoning is that the small model may not have enough capacity to match the soft targets exactly, so “erring in the direction of the correct answer turns out to be helpful”. Writing the output of a model at temperature \(T\) as \(f(x; T)\), the complete loss is

\[ L_{\text{distill}}(x,y) = a T^2 \cdot L\left[ f_{\text{small}}(x; T), f_{\text{big}}(x; T) \right] + (1-a) \cdot L \left [ f_{\text{small}}(x; 1), y \right ] \]

The first term is scaled by \(T^2\) because the magnitudes of the gradients produced by the soft targets scale as \(1/T^2\). The rescaling keeps the relative contribution of the two terms roughly unchanged when \(T\) is varied, so that it is controlled by \(a\) alone.
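Below is a minimal pure-Python sketch of \(L_{\text{distill}}\) for a single example, followed by a finite-difference check of the \(1/T^2\) gradient scaling. The logits are made up and, following the paper's high-temperature argument, assumed to have zero mean; in that regime \(T^2\) times the soft-term gradient should approach \((z_c - v_c)/N\), where \(z\) and \(v\) are the small and big model's logits and \(N\) is the number of classes.

```python
import math

def softmax(logits, T=1.0):
    """Softmax of logits / T, computed stably by subtracting the max."""
    scaled = [z / T for z in logits]
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def cross_entropy(y_hat, y):
    """L(y_hat, y) = -sum_c y_c log y_hat_c."""
    return -sum(t * math.log(p) for t, p in zip(y, y_hat) if t > 0)

def distill_loss(small_logits, big_logits, label, T, a):
    """a T^2 L[f_small(x;T), f_big(x;T)] + (1 - a) L[f_small(x;1), y] for one example."""
    soft_term = cross_entropy(softmax(small_logits, T), softmax(big_logits, T))
    one_hot = [1.0 if c == label else 0.0 for c in range(len(small_logits))]
    hard_term = cross_entropy(softmax(small_logits, 1.0), one_hot)
    return a * T ** 2 * soft_term + (1 - a) * hard_term

def soft_term_grad(small_logits, big_logits, T, h=1e-4):
    """Central finite-difference gradient of the unscaled soft term w.r.t. the small logits."""
    def f(z):
        return cross_entropy(softmax(z, T), softmax(big_logits, T))
    grad = []
    for i in range(len(small_logits)):
        plus, minus = list(small_logits), list(small_logits)
        plus[i] += h
        minus[i] -= h
        grad.append((f(plus) - f(minus)) / (2 * h))
    return grad

# Made-up, zero-mean logits for a 3-class problem.
z = [1.0, -0.5, -0.5]  # small model
v = [0.5, 0.0, -0.5]   # big model

print(distill_loss(z, v, label=0, T=4.0, a=0.5))

# The raw gradient shrinks like 1/T^2, but T^2 times it approaches (z - v) / N.
for T in (1.0, 5.0, 20.0, 100.0):
    g = soft_term_grad(z, v, T)
    print(T, [round(T ** 2 * gi, 4) for gi in g])
```

Without the \(T^2\) factor, raising the temperature would silently shrink the soft term's influence relative to the hard-label term.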

MNIST

We try out distillation on the small-scale MNIST experiment that the authors describe. The big model is an MLP with two hidden ReLU layers, regularized with dropout, a max-norm constraint on the weights, and jitter augmentation of the input images. The small model is trained without any of these regularizers.
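The post's `aug_transform` is not shown. As a hedged stand-in, the sketch below assumes that "jitter" means a random translation of up to two pixels in each direction, which is how the paper describes it. `shift` and `jitter` are hypothetical helpers that operate on an image stored as a list of rows.

```python
import random

def shift(image, dy, dx):
    """Translate a 2D image (list of rows) by dy rows and dx columns,
    filling pixels shifted in from outside the frame with 0."""
    h, w = len(image), len(image[0])
    out = [[0.0] * w for _ in range(h)]
    for r in range(h):
        for c in range(w):
            src_r, src_c = r - dy, c - dx
            if 0 <= src_r < h and 0 <= src_c < w:
                out[r][c] = image[src_r][src_c]
    return out

def jitter(image, max_shift=2, rng=random):
    """Randomly translate the image by up to max_shift pixels in each direction."""
    dy = rng.randint(-max_shift, max_shift)
    dx = rng.randint(-max_shift, max_shift)
    return shift(image, dy, dx)
```

For example, `jitter(image, rng=random.Random(0))` gives a reproducible shift. In torchvision, something like `transforms.RandomAffine(degrees=0, translate=(2/28, 2/28))` should have a similar effect on 28×28 images.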

Model definition
import torch
import torch.nn as nn


class Model(nn.Module):
    '''
    Used in MNIST experiments.
    An MLP with two hidden ReLU layers, dropout, and a max-norm weight constraint.
    '''
    def __init__(self, hidden_size, max_norm = 2.0, drop_rate = 0.5):
        super().__init__()
        self.max_norm = max_norm
        self.layers = nn.Sequential(
            nn.Flatten(),
            nn.Linear(28 * 28, hidden_size),
            nn.ReLU(),
            nn.Dropout(drop_rate),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Dropout(drop_rate),
            nn.Linear(hidden_size, 10)
        )

    def forward(self, x):
        # Max-norm constraint: rescale each unit's incoming weight vector
        # (a row of nn.Linear.weight) so that its L2 norm is at most max_norm
        if self.max_norm is not None:
            with torch.no_grad():
                for layer in self.modules():
                    if isinstance(layer, nn.Linear):
                        norm = layer.weight.norm(2, dim=1, keepdim=True)
                        desired = torch.clamp(norm, max=self.max_norm)
                        # clamp_min guards against division by zero for all-zero rows
                        layer.weight.mul_(desired / norm.clamp_min(1e-12))

        return self.layers(x)

We define the distillation loss:

Define training losses
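import torch
import torch.nn.functional as F


# Regular cross-entropy loss on the hard labels; extra arguments are ignored
def hard_loss(outputs, labels, criterion, *args):
    return criterion(outputs, labels)

# Distillation loss: a * T^2 * CE(student at T, teacher at T) + (1 - a) * CE(student at T = 1, labels)
def soft_loss(outputs, labels, criterion, examples, big_model, T, a):
    # Soft labels from the frozen big model at temperature T
    with torch.no_grad():
        big_model.eval()
        soft_labels = F.softmax(big_model(examples) / T, dim = 1)

    # criterion is assumed to be nn.CrossEntropyLoss, which accepts class
    # probabilities as targets (PyTorch >= 1.10)
    return a * (T ** 2) * criterion(outputs / T, soft_labels) + (1 - a) * criterion(outputs, labels)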

The hidden dimensions of the big and small networks are 1200 and 800, respectively. We train the networks with early stopping on a held-out validation set and choose \(T = 4.0\) and \(a = 0.5\). The paper reports a temperature of 20 for this setting but does not give the weight on the hard-label term, so both values here are our own choice.
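The `train` helper is not shown in the post. The sketch below is a hypothetical version of the common patience rule it presumably implements with `patience = 7`: stop once the validation loss has not improved for `patience` consecutive epochs, and keep the weights from the best epoch. Whether `train` actually restores the best weights is an assumption.

```python
def early_stopping_epoch(val_losses, patience=7):
    """Return (best_epoch, stop_epoch) for patience-based early stopping.

    Training stops once the validation loss has not improved for `patience`
    consecutive epochs; the weights from `best_epoch` are the ones kept.
    """
    best_loss = float("inf")
    best_epoch = 0
    epochs_without_improvement = 0
    for epoch, loss in enumerate(val_losses):
        if loss < best_loss:
            best_loss, best_epoch = loss, epoch
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                return best_epoch, epoch
    # Ran out of epochs before patience was exhausted
    return best_epoch, len(val_losses) - 1

# Made-up validation losses
losses = [0.5, 0.3, 0.25, 0.26, 0.27, 0.24, 0.25, 0.26]
print(early_stopping_epoch(losses, patience=2))  # (2, 4)
print(early_stopping_epoch(losses, patience=3))  # (5, 7)
```

A larger patience tolerates noisy plateaus (here, patience 3 reaches the later minimum at epoch 5) at the cost of a few extra epochs of training.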

Train the big model
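from torch.utils.data import DataLoader

# Assumed to be defined elsewhere in the post (not shown): train, evaluate_model,
# save_results, set_seed, device, aug_transform, reg_transform, and the MNIST
# train_dataset / val_dataset / test_dataset splits.

# Hyperparameters
num_epochs = 100
batch_size = 128
lr = 0.01
patience = 7
big_model_size = 1200
small_model_size = 800

def loader(ds, shuffle = False):
    return DataLoader(ds, batch_size = batch_size, shuffle = shuffle)

val_loader   = loader(val_dataset)
test_loader  = loader(test_dataset)

# Train the big model with jitter augmentation.
# Note: if train_dataset and val_dataset come from random_split of the same
# dataset, setting the transform here also applies it to the validation set.
train_dataset.dataset.transform = aug_transform
train_loader = loader(train_dataset, shuffle = True)

big_model = Model(big_model_size).to(device)
big_train_history = train(big_model, hard_loss, train_loader, val_loader, num_epochs, lr, patience)
test_loss, test_accuracy = evaluate_model(big_model, test_loader)

save_results(big_model, big_train_history, 'big_model')
print(test_loss, test_accuracy)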
Train the smaller model on hard labels
def set_base_transform(subset, transform):
    # Unwrap (possibly nested) Subsets so the transform reaches the underlying dataset
    base = subset
    while isinstance(base, torch.utils.data.Subset):
        base = base.dataset
    base.transform = transform

def train_small_model(train_dataset, val_dataset, seed, loss, model_size = small_model_size):
    set_seed(seed)
    # No augmentation
    set_base_transform(train_dataset, reg_transform)
    train_loader = loader(train_dataset, shuffle = True)
    val_loader = loader(val_dataset)

    # ... and no regularization (no max norm, no dropout)
    small_model = Model(model_size, max_norm = None, drop_rate = 0.0).to(device)
    small_train_history = train(small_model, loss, train_loader, val_loader, num_epochs, lr, patience)
    return small_model, small_train_history

# train_small_model takes no temperature argument; hard_loss ignores T anyway
small_model, small_train_history = train_small_model(train_dataset, val_dataset, seed = 42, loss = hard_loss)
save_results(small_model, small_train_history, 'small_model')
evaluate_model(small_model, test_loader)
Train the distilled model
import functools

temperature, a = 4.0, 0.5

distilled_model, distilled_train_history = train_small_model(
    train_dataset = train_dataset, val_dataset = val_dataset, seed = 42,
    # Bind the teacher and distillation hyperparameters; the remaining arguments
    # (outputs, labels, criterion, examples) are assumed to be supplied by train()
    loss = functools.partial(soft_loss, big_model = big_model, T = temperature, a = a)
)
save_results(distilled_model, distilled_train_history, 'distilled_model')
evaluate_model(distilled_model, test_loader)

We obtain the following test accuracies:

big: 0.9901, small: 0.9833, distilled: 0.9891
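Since the paper reports results as error counts on the 10,000 MNIST test images (67, 146 and 74 errors for its big, small and distilled networks), converting the accuracies above into the same unit makes the comparison easier. The "gap closed" ratio is our own summary statistic, not one used in the paper.

```python
# Test accuracies reported above
accuracies = {"big": 0.9901, "small": 0.9833, "distilled": 0.9891}

# MNIST's test set has 10,000 images
n_test = 10_000
errors = {name: round((1 - acc) * n_test) for name, acc in accuracies.items()}
print(errors)  # {'big': 99, 'small': 167, 'distilled': 109}

# Fraction of the small-to-big error gap that distillation closes
gap_closed = (errors["small"] - errors["distilled"]) / (errors["small"] - errors["big"])
print(round(gap_closed, 2))  # 0.85
```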

Mythical 3

The authors then remove all examples of the digit 3 from the transfer set that the distilled model is trained on, to test how well it generalizes to a class it never sees (the big model itself was trained on all ten digits). “So from the perspective of the distilled model, 3 is a mythical digit that it has never seen”. When we evaluate on the test set, which still contains 3s, the distilled model performs much better than a small model trained on hard labels:

Train without 3s in transfer set
import functools

import numpy as np
import torch
from torchvision import datasets

# Remove all 3s from the dataset (no augmentation, matching train_small_model)
train_dataset = datasets.MNIST(root = DATA_DIR, train = True,  download = True, transform = reg_transform)
train_dataset = torch.utils.data.Subset(train_dataset, np.where(train_dataset.targets.numpy() != 3)[0])

# Split training data into train and validation sets
train_size = int(0.9 * len(train_dataset))
val_size = len(train_dataset) - train_size
train_dataset, val_dataset = torch.utils.data.random_split(
    train_dataset, [train_size, val_size]
)

# Train small without distillation
small_no_3, small_no_3_history = train_small_model(
    train_dataset = train_dataset, val_dataset = val_dataset,
    seed = 42, loss = hard_loss, model_size = 800
)

# Train small with distillation from the big model, which did see 3s
distilled_no_3, distilled_no_3_history = train_small_model(
    train_dataset = train_dataset, val_dataset = val_dataset,
    seed = 42, loss = functools.partial(soft_loss, big_model = big_model, T = 4.0, a = 0.5),
    model_size = 800
)

# Evaluate on the full test set, which still contains 3s
print(evaluate_model(small_no_3, test_loader))
print(evaluate_model(distilled_no_3, test_loader))
Not distilled: 0.8882, distilled: 0.9869
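A toy calculation suggests why soft targets can teach the student about a class it never sees. For an image of another digit, a hard label always pushes the student's logit for 3 down, whereas a soft label in which the teacher gives 3 more probability than the student currently does pushes it up. The sketch below uses made-up logits for a hypothetical "5" that the teacher finds somewhat 3-like, and looks only at the sign of the gradient of each cross-entropy term (the positive factors \(aT^2\) and \(1-a\) do not change the sign).

```python
import math

def softmax(logits, T=1.0):
    scaled = [z / T for z in logits]
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def ce_grad(student_logits, target, T=1.0):
    """Gradient of -sum_c target_c * log softmax(z / T)_c w.r.t. the logits z."""
    q = softmax(student_logits, T)
    return [(qc - tc) / T for qc, tc in zip(q, target)]

# Hypothetical 10-class example: an image of a 5 whose teacher logits
# give class 3 the second-highest score.
student_logits = [0.0] * 10
teacher_logits = [0.0] * 10
teacher_logits[5] = 6.0
teacher_logits[3] = 4.0

hard_target = [1.0 if c == 5 else 0.0 for c in range(10)]
soft_target = softmax(teacher_logits, T=4.0)

g_hard = ce_grad(student_logits, hard_target, T=1.0)
g_soft = ce_grad(student_logits, soft_target, T=4.0)

# A gradient-descent step changes logit 3 by -lr * gradient
print("hard-label gradient on logit 3:", g_hard[3])  # positive: pushes logit 3 down
print("soft-label gradient on logit 3:", g_soft[3])  # negative: pushes logit 3 up
```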

In the paper, the authors take this to the extreme: a distilled model whose transfer set contains only 7s and 8s still achieves a surprisingly low test error on all ten digits once the biases of its 7 and 8 outputs are lowered. They also run experiments on a large speech recognition dataset and discuss training specialist models on a large computer vision dataset, using distillation from a generalist model as regularization.

Final thoughts

It was very fun to return to this classic paper. It introduced a simple yet powerful idea that is still widely used today. Like many papers from that era (circa 2015), it is clear and readable. And, as is a Hinton staple, it is slightly bio-inspired, in this case by insect larvae.

Some pointers to work that extends this idea: self-distillation makes the teacher (“big”) and student (“small”) models the same size, and in mutual learning two or more networks learn collaboratively from each other. Most extensions, however, build on the paper's main theme: training on a richer signal. For example, the student can be trained to imitate the teacher's intermediate (or final) representations, attention maps, and so on.