Batch Norm exercises

deep learning
exercises
Published

July 2, 2024

Where I attempt to solve the exercises in section 8.5 of the d2l book from scratch in PyTorch (without using the d2l library).

Imports
try:
    import matplotlib.pyplot as plt
except ImportError:
    !pip3 install matplotlib
    import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim

from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Dataset
from PIL import Image

import numpy as np
import pandas as pd
import seaborn as sns
import os, itertools, time, random

device = 'cpu'

Context

The section introduces how to implement batch norm (BN) and some of the intuitions behind its effectiveness.

As a quick recap, batch norm layers apply the following transform to their inputs:

\[ BN(x) = \gamma \odot \frac{x-\hat \mu_B}{\hat \sigma_B} + \beta \]

Where \(\gamma\), \(\beta\) are learned and \(\hat \mu_B\), \(\hat \sigma_B\) are the mean and standard deviation estimated from the input’s minibatch \(B\) during training.

I.e. batch norm first normalizes the input to have mean \(0\) and std \(1\), facilitating convergence during optimization.
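As a quick sanity check of the formula, here is a minimal sketch that applies it by hand and compares the result with `nn.BatchNorm1d` in training mode. The tensor shapes and the random input are arbitrary choices for illustration; the only PyTorch-specific details assumed are that the layer normalizes with the biased variance and adds a small `eps` inside the square root.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(32, 4) * 3 + 5           # minibatch: 32 samples, 4 features

bn = nn.BatchNorm1d(4)                    # gamma (weight) = 1, beta (bias) = 0 at init
bn.train()                                # use minibatch statistics

mu = x.mean(dim=0)
var = x.var(dim=0, unbiased=False)        # biased variance, as used for normalization
manual = bn.weight * (x - mu) / torch.sqrt(var + bn.eps) + bn.bias

out = bn(x)
max_diff = (out - manual).abs().max().item()
print(max_diff)                           # difference between the two computations
print(out.mean(dim=0), out.std(dim=0, unbiased=False))
```

Per feature, the output should have mean close to \(0\) and standard deviation close to \(1\).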

However, since BN is typically applied before the activation (at least traditionally), normalizing on its own would reduce the expressive power of the layer. For instance, as pointed out by the original paper, “normalizing the inputs of a sigmoid would constrain them to the linear regime of the nonlinearity.” Below, we plot a sigmoid and note that in the [-1, 1] range (where most of the normalized data would fall) it is essentially linear.

Plot sigmoid
f, (ax1, ax2) = plt.subplots(1, 2, figsize=(6, 3))
t_constrained = np.linspace(-1, 1, 100)
t_full = np.linspace(-4, 4, 100)
sigmoid = lambda t:1/(1 + np.exp(-t))
ax1.plot(t_full, sigmoid(t_full))
ax1.plot(t_constrained, sigmoid(t_constrained), c = 'r')
ax2.plot(t_constrained, sigmoid(t_constrained), c = 'r')
ax1.set_ylabel('sigmoid(x)')
ax1.set_xlabel('x')
ax2.set_xlabel('x')
f.tight_layout()
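To put a rough number on “essentially linear”, the sketch below compares the sigmoid with its tangent line at \(0\), namely \(0.5 + x/4\), over the [-1, 1] range. The grid resolution is an arbitrary choice.

```python
import numpy as np

t = np.linspace(-1, 1, 201)
sigmoid_vals = 1 / (1 + np.exp(-t))
tangent = 0.5 + t / 4                     # first-order Taylor expansion at 0
max_dev = np.abs(sigmoid_vals - tangent).max()
print(max_dev)                            # largest gap, reached at the endpoints
```

The largest gap is below \(0.02\), i.e. small compared to the sigmoid's output range of \(1\).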

So to maintain the layer’s expressive power (degrees of freedom), “we make sure that the transformation inserted in the network can represent the identity transform” and introduce \(\gamma\) and \(\beta\). So if it is optimal to leave the input unchanged, the network can learn to do so by setting \(\gamma = \hat \sigma_B\) and \(\beta = \hat \mu_B\). (This part was confusing as the section only mentions: “Next, we apply a scale coefficient and an offset to recover the lost degrees of freedom” - the paper provided clarification).
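The identity argument can be checked numerically. In this sketch (arbitrary random input, `nn.BatchNorm1d` in training mode), the scale is set to the minibatch standard deviation and the offset to the minibatch mean, after which the layer should return its input almost unchanged.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(64, 3) * 2.5 - 1.0

bn = nn.BatchNorm1d(3)
bn.train()
with torch.no_grad():
    # gamma = minibatch std (including eps, to match the layer), beta = minibatch mean
    bn.weight.copy_(torch.sqrt(x.var(dim=0, unbiased=False) + bn.eps))
    bn.bias.copy_(x.mean(dim=0))
    recovered = bn(x)

identity_error = (recovered - x).abs().max().item()
print(identity_error)                     # only floating-point error remains
```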

Finally, the second reason batch norm seems to help is the implicit regularization it provides by injecting noise into the training process. What noise? \(\hat \mu_B\) and \(\hat \sigma_B\) are (noisy) estimates calculated on a sample (the minibatch). Thus, the size of the minibatch \(|B|\) plays an important role: too small and the estimates are too high variance; too big and the estimates become too stable (noiseless).
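The dependence of this noise on \(|B|\) can be illustrated with a small simulation. The sketch below draws many minibatches from a synthetic population (an assumption made purely for illustration) and measures how much the minibatch mean fluctuates for a few batch sizes.

```python
import torch

torch.manual_seed(0)
population = torch.randn(100_000) * 2 + 1     # synthetic data: mean 1, std 2

def spread_of_batch_means(batch_size, num_batches=1000):
    # Sample minibatches with replacement and compute each minibatch mean
    idx = torch.randint(0, population.numel(), (num_batches, batch_size))
    batch_means = population[idx].mean(dim=1)
    return batch_means.std().item()           # how noisy the estimate is

spreads = {b: spread_of_batch_means(b) for b in [4, 32, 256]}
print(spreads)
```

The spread shrinks roughly as \(1/\sqrt{|B|}\), so larger minibatches inject less noise.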

Anyway, let’s get to the exercises.

Q1

Should we remove the bias parameter from the fully connected layer or the convolutional layer before the batch normalization? Why?

I believe we could remove the bias parameter from both the fully connected and convolution layers if BN is applied as described in the section: right after the fully connected / convolution layer but before the activation \(\phi\). Why? Essentially, BN is invariant to a constant (per-feature) shift \(\alpha\) of its input, because shifting the input shifts the minibatch mean by the same amount and the subtraction cancels it:

\[ BN_{\gamma, \beta}(x + \alpha) = \gamma \odot \frac{(x + \alpha) - (\hat \mu_B + \alpha)}{\hat \sigma_B} + \beta = BN_{\gamma, \beta}(x) \]

Thus in the fully connected layer case: \[ \boldsymbol h = \phi(BN_{\gamma, \beta}(\boldsymbol{Wx + b})) = \phi(BN_{\gamma, \beta}(\boldsymbol{Wx})) \] Any offset the layer needs is already provided by \(\beta\).

In convolution layers, we apply BN per channel, across all locations. I.e. “each channel has its own scale and shift parameters, both of which are scalars”. And since the convolution layer also adds a scalar bias per channel, a similar argument applies.

But should we remove the biases? Yes. We get the same expressive power with fewer params.

Let’s try it out empirically on MNIST by training the BNLeNet network defined in the section, removing bias on linear and convolution layers, in turn. We won’t get bit-identical learned parameters across runs, but we should get comparable performance.

Model definition
def init_cnn(module):
    """Initialize weights for CNNs."""
    if type(module) in [nn.Linear, nn.Conv2d]:
        nn.init.xavier_uniform_(module.weight)


class BNLeNet(nn.Module):
    def __init__(self, num_classes=10, removed_bias = 'linear'):
        super().__init__()
        assert removed_bias in ['linear', 'conv', 'none']
        # A layer keeps its bias unless its type was selected for removal
        conv_bias = removed_bias != 'conv'
        linear_bias = removed_bias != 'linear'
        self.net = nn.Sequential(
            nn.LazyConv2d(6, kernel_size=5, bias = conv_bias), nn.LazyBatchNorm2d(),
            nn.Sigmoid(), nn.AvgPool2d(kernel_size=2, stride=2),
            nn.LazyConv2d(16, kernel_size=5, bias = conv_bias), nn.LazyBatchNorm2d(),
            nn.Sigmoid(), nn.AvgPool2d(kernel_size=2, stride=2),
            nn.Flatten(), nn.LazyLinear(120, bias = linear_bias), nn.LazyBatchNorm1d(),
            nn.Sigmoid(), nn.LazyLinear(84, bias = linear_bias), nn.LazyBatchNorm1d(),
            nn.Sigmoid(), nn.LazyLinear(num_classes))  # last layer has no BN after it, so it keeps its bias
        
    def forward(self, x):
        return self.net(x)
Training function
def train(net, train_loader, test_loader, num_epochs = 5, lr = 0.1, verbose = True):
    # Infer input shapes, initialize weights and move to device
    _ = net(next(iter(train_loader))[0]) # Necessary before initing weights
    net.apply(init_cnn)
    net.to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=lr)
    
    test_accs = []
    for epoch in range(num_epochs):
        net.train()
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = net(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
        
        eval_acc = eval_net(net, test_loader)
        test_accs.append(eval_acc)
        if verbose: print(f'Epoch {epoch + 1}/{num_epochs}, Test acc: {eval_acc}')
        
    return test_accs
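The training function calls `eval_net`, and the cells below use `training_data` and `testing_data`, none of which are shown in this post. The following is a minimal sketch of what they might look like, not the original code: `eval_net` is assumed to return the fraction of correctly classified samples, and the datasets are assumed to be torchvision's MNIST with a plain `ToTensor` transform (left as comments so the sketch runs without downloading anything).

```python
import torch

# Assumed data setup (hypothetical, not shown in the post):
# training_data = datasets.MNIST('data', train=True, download=True, transform=transforms.ToTensor())
# testing_data = datasets.MNIST('data', train=False, download=True, transform=transforms.ToTensor())

def eval_net(net, data_loader, device='cpu'):
    """Return the fraction of correctly classified samples (hypothetical helper)."""
    net.eval()                            # BN layers switch to their running statistics
    correct, total = 0, 0
    with torch.no_grad():
        for images, labels in data_loader:
            images, labels = images.to(device), labels.to(device)
            preds = net(images).argmax(dim=1)
            correct += (preds == labels).sum().item()
            total += labels.numel()
    return correct / total
```

Since it switches the network to evaluation mode, the training loop has to call `net.train()` again at the start of each epoch, which the function above already does.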
Train and evaluate
# Set data loader seed for reproducibility
def seed_worker(worker_id):
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)

test_loader = DataLoader(testing_data, batch_size = 128, shuffle = False)

for removed_bias in ['none', 'linear', 'conv']:

    torch.manual_seed(0); np.random.seed(0); random.seed(0)
    g = torch.Generator()
    g.manual_seed(0)
    net = BNLeNet(removed_bias = removed_bias)
    train_loader = DataLoader(training_data, batch_size = 128, shuffle = True, worker_init_fn = seed_worker, generator = g)

    train(net, train_loader, test_loader, num_epochs = 10, lr = 0.1, verbose = False)
    print(f'{removed_bias} final test acc: \t{eval_net(net, test_loader)}')
none final test acc:    0.9821
linear final test acc:  0.9869
conv final test acc:    0.9824

Close enough.
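Comparing accuracies is an indirect test. A more direct one is to check that a bias placed right before a batch norm layer changes neither the output nor the loss. The sketch below uses a single linear layer, random inputs, and an arbitrary random loss, all of which are assumptions made for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(16, 8)
target_weights = torch.randn(16, 4)       # arbitrary weights for a toy loss

linear = nn.Linear(8, 4)                  # has a (randomly initialized) bias
bn = nn.BatchNorm1d(4)
bn.train()

with_bias = bn(linear(x))
without_bias = bn(x @ linear.weight.T)    # same weights, bias dropped
max_diff = (with_bias - without_bias).abs().max().item()

# Gradients reaching the bias through BN cancel out across the minibatch
(with_bias * target_weights).sum().backward()
bias_grad = linear.bias.grad.abs().max().item()
weight_grad = linear.weight.grad.abs().max().item()
print(max_diff, bias_grad, weight_grad)
```

Both the output difference and the bias gradient should be zero up to floating-point error, while the weight gradient is not: in training mode the bias is never updated by the loss, so it is dead weight.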

Q2

Compare the learning rates for LeNet with and without batch normalization.

  1. Plot the increase in validation accuracy.
  2. How large can you make the learning rate before the optimization fails in both cases?
Model definitions
class LeNet(nn.Module):

    def __init__(self, num_classes=10):
        super().__init__()
        self.net = nn.Sequential(
            nn.LazyConv2d(6, kernel_size=5, padding=2), nn.Sigmoid(),
            nn.AvgPool2d(kernel_size=2, stride=2),
            nn.LazyConv2d(16, kernel_size=5), nn.Sigmoid(),
            nn.AvgPool2d(kernel_size=2, stride=2),
            nn.Flatten(),
            nn.LazyLinear(120), nn.Sigmoid(),
            nn.LazyLinear(84), nn.Sigmoid(),
            nn.LazyLinear(num_classes)
        )

    def forward(self, X):
        return self.net(X)
        
class BNLeNet(nn.Module):
    def __init__(self, num_classes=10):
        super().__init__()
        self.net = nn.Sequential(
            nn.LazyConv2d(6, kernel_size=5), nn.LazyBatchNorm2d(), nn.Sigmoid(),
            nn.AvgPool2d(kernel_size=2, stride=2),
            nn.LazyConv2d(16, kernel_size=5), nn.LazyBatchNorm2d(), nn.Sigmoid(),
            nn.AvgPool2d(kernel_size=2, stride=2),
            nn.Flatten(),
            nn.LazyLinear(120), nn.LazyBatchNorm1d(), nn.Sigmoid(),
            nn.LazyLinear(84), nn.LazyBatchNorm1d(), nn.Sigmoid(),
            nn.LazyLinear(num_classes))
        
    def forward(self, x):
        return self.net(x)

We tried \(lr \in \{4, 2, 1, 0.5, 0.1, 0.05\}\) as \(0.05\) and \(4\) are the lowest and highest rates where LeNet still trains (in 10 epochs).
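The sweep itself is not shown; the plotting code below expects `lrs` plus two lists, `without_bn` and `with_bn`, each holding one accuracy curve per learning rate. Here is a sketch of that structure on a toy problem. The small MLP, the synthetic data, and the use of training accuracy per step are all stand-ins chosen so the example runs in seconds; the actual experiment would call `train` on `LeNet` and `BNLeNet` with MNIST.

```python
import torch
import torch.nn as nn

def make_mlp(use_bn):
    # Toy stand-in for LeNet / BNLeNet
    layers = [nn.Linear(20, 32)]
    if use_bn:
        layers.append(nn.BatchNorm1d(32))
    layers += [nn.Sigmoid(), nn.Linear(32, 2)]
    return nn.Sequential(*layers)

def accuracy_curve(use_bn, lr, steps=30, seed=0):
    torch.manual_seed(seed)               # same data and init for every lr
    X = torch.randn(256, 20)
    y = (X[:, 0] > 0).long()              # synthetic binary labels
    net = make_mlp(use_bn)
    opt = torch.optim.SGD(net.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()
    accs = []
    for _ in range(steps):
        opt.zero_grad()
        logits = net(X)
        criterion(logits, y).backward()
        opt.step()
        # Training accuracy of the step just taken (toy proxy for validation accuracy)
        accs.append((logits.argmax(dim=1) == y).float().mean().item())
    return accs

lrs = [4, 2, 1, 0.5, 0.1, 0.05]
without_bn = [accuracy_curve(False, lr) for lr in lrs]
with_bn = [accuracy_curve(True, lr) for lr in lrs]
```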

Plot results
f, (ax1, ax2) = plt.subplots(1, 2, figsize=(8, 4), sharey = True)

for ax, name, results in zip([ax1, ax2], ['Without BN', 'With BN'], [without_bn, with_bn]):
    for lr, accs in zip(lrs, results):
        ax.plot(accs, label=f'lr={lr}')
        ax.set_title(f'{name}')
        ax.set_xlabel('Epoch')

ax1.set_ylabel('Validation Accuracy')
ax2.legend()
f.tight_layout()

And we observe that applying batch norm helps early training (we achieve better performance earlier) and makes training more robust to learning rate selection.

Understanding Batch Normalization posits that batch norm’s main benefit is that it allows for greater learning rates by containing activation blowup (especially in later layers), which in turn biases the optimization towards “flatter” minima with better generalization.

It seems our small experiment aligns with the paper, even though our network is quite shallow.
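The activation-blowup part of that argument can be illustrated on a toy stack of layers. The sketch below is not the paper's experiment: it only runs a forward pass through randomly initialized linear + ReLU layers with a deliberately over-scaled init (the `gain`, `depth` and `width` values are arbitrary assumptions) and records the mean squared activation after each layer, with and without batch norm.

```python
import torch
import torch.nn as nn

def mean_square_per_layer(use_bn, depth=10, width=256, gain=2.0, seed=0):
    torch.manual_seed(seed)
    h = torch.randn(128, width)
    mags = []
    with torch.no_grad():
        for _ in range(depth):
            linear = nn.Linear(width, width, bias=False)
            # Over-scaled init: each layer amplifies its input on average
            nn.init.normal_(linear.weight, std=gain * (2.0 / width) ** 0.5)
            h = linear(h)
            if use_bn:
                h = nn.BatchNorm1d(width).train()(h)   # normalize with minibatch stats
            h = torch.relu(h)
            mags.append((h ** 2).mean().item())
    return mags

plain = mean_square_per_layer(use_bn=False)
normed = mean_square_per_layer(use_bn=True)
print(plain[-1], normed[-1])
```

Without batch norm the magnitudes grow geometrically with depth, whereas with batch norm they stay of order one at every layer.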

Q3

Do we need batch normalization in every layer? Experiment with it.

The paper demonstrates that it is more beneficial in later layers. Let’s see if it is in our case and remove each batch norm layer in turn. We logged test accuracies and activations and tested 5 nets per configuration.

Redefine the model
class BNLeNet(nn.Module):
    def __init__(self, num_classes = 10, ex_bn_layers = ()):
        # ex_bn_layers: list of BN layers to exclude (1, ..., 4)

        super().__init__()
        for i in ex_bn_layers: assert i in range(1, 5), 'There are only 4 BN layers'

        layers = [
            nn.LazyConv2d(6, kernel_size=5), nn.LazyBatchNorm2d(), nn.Sigmoid(),
            nn.AvgPool2d(kernel_size=2, stride=2),
            nn.LazyConv2d(16, kernel_size=5), nn.LazyBatchNorm2d(), nn.Sigmoid(),
            nn.AvgPool2d(kernel_size=2, stride=2),
            nn.Flatten(),
            nn.LazyLinear(120), nn.LazyBatchNorm1d(), nn.Sigmoid(),
            nn.LazyLinear(84), nn.LazyBatchNorm1d(), nn.Sigmoid(),
            nn.LazyLinear(num_classes)
        ]

        bn_idx = [i for i, module in enumerate(layers) if 'BatchNorm' in str(module)]
        for i in ex_bn_layers: layers[bn_idx[i - 1]] = nn.Identity()

        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)
Train and evaluate
train_loader = DataLoader(training_data, batch_size = 128, shuffle = True)
test_loader = DataLoader(testing_data, batch_size = 128, shuffle = False)
imgs, labels = next(iter(train_loader))

ex_bn_layers = [[], [1], [2], [3], [4]]
mag_layers = [2, 6, 11, 14]
results = []

for ex in ex_bn_layers:

    print(ex)

    for epochs in [2, 4, 8]:

        print(epochs)
    
        for i in range(5):

            net = BNLeNet(ex_bn_layers = ex)
            test_accs = train(net, train_loader, test_loader, num_epochs = epochs, lr = 0.1, verbose = False)

            tmp = {
                'Excluded BN layers': ex[0] if ex else 0,
                'Epochs': epochs,
                'Test accuracy': test_accs[-1],
                'Iter': i
            }
            
            for lyr in mag_layers:
                sub = net.net[:lyr]
                with torch.no_grad(): 
                    mag = (sub(imgs) ** 2).mean()
                    tmp[f'Mean mag at {lyr}'] = mag.item()
            
            results.append(tmp)

Does performance change?

Plot test accuracies
res = pd.DataFrame(results)
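# Assign the result back: in-place replace on a selected column is unreliable in recent pandas
res['Excluded BN layers'] = res['Excluded BN layers'].replace({0: 'None', 1: '2', 2: '6', 3: '11', 4: '14'})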

sns.barplot(data = res, x = 'Excluded BN layers', y = 'Test accuracy', hue = 'Epochs')
plt.ylim(0.5, 1)

It seems that removing the first batch norm (in the second layer) has the strongest hit on performance.

What about activations? Let’s plot the mean squared activation (a proxy for the \(l_2\) magnitude) at the output of each batch norm position, i.e. the input to the layer that comes right after it, for a single minibatch:

Plot mean magnitudes
f, axs = plt.subplots(1, 4, figsize=(16, 4))

for i, lyr in enumerate(mag_layers):
    sns.barplot(data = res, x = 'Excluded BN layers', y = f'Mean mag at {lyr}', ax = axs[i], hue='Epochs')
    axs[i].set_title(f'Layer {lyr}')
    axs[i].set_ylabel('')

axs[0].set_ylabel('Mean act L2 magnitude')
f.tight_layout()

We can definitely see an effect when we remove batch norm layers. In general, the activation magnitudes decrease in the removed layer. I.e. layer 2’s magnitudes when its batch norm is removed are lower than in the original network, and so on for the other layers. It also seems that other layers compensate for this decrease by increasing their magnitudes. Finally, as the paper points out, the effects appear stronger in early epochs.
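The measurement relies on slicing an `nn.Sequential`: a slice is itself an `nn.Sequential` that shares its modules with the full network. Below is a small sketch on a toy network (an assumption for illustration, not the BNLeNet above). It also shows that the result depends on whether batch norm is in training or evaluation mode, which is worth keeping in mind when reading the magnitudes.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
net = nn.Sequential(nn.Linear(8, 16), nn.BatchNorm1d(16), nn.Sigmoid(), nn.Linear(16, 4))
x = torch.randn(32, 8)

sub = net[:2]                             # Linear + BatchNorm1d, shared with net
shares_modules = sub[0] is net[0] and sub[1] is net[1]

with torch.no_grad():
    net.eval()                            # BN uses running statistics
    mag_eval = (sub(x) ** 2).mean().item()
    net.train()                           # BN uses minibatch statistics
    mag_train = (sub(x) ** 2).mean().item()

print(shares_modules, mag_eval, mag_train)
```

In training mode with the default \(\gamma = 1\) and \(\beta = 0\), the mean squared output of a batch norm layer is close to \(1\) by construction, so after training any departure from that value reflects the learned \(\gamma\) and \(\beta\).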

Q4

Implement a “lite” version of batch normalization that only removes the mean, or alternatively one that only removes the variance. How does it behave?

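We can approximate this by freezing BatchNorm’s bias (\(\beta\)) and weight (\(\gamma\)) params respectively. Note that this keeps the normalization itself intact and only fixes the learned shift or scale at its default (\(0\) or \(1\)), so “remove mean” below means a frozen \(\beta\) and “remove var” means a frozen \(\gamma\):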

Train and evaluate
for remove_mean, remove_var in [[True, True], [False, True], [True, False], [False, False]]:

    # All nets "see" the same data
    torch.manual_seed(0); np.random.seed(0); random.seed(0)
    g = torch.Generator()
    g.manual_seed(0)

    net = BNLeNet()
    for module in net.net:
        if 'BatchNorm' in str(module):
            if remove_mean:
                module.bias.requires_grad = False
            if remove_var:
                module.weight.requires_grad = False

    train_loader = DataLoader(training_data, batch_size = 128, shuffle = True, worker_init_fn = seed_worker, generator = g)

    test_accs = train(net, train_loader, test_loader, num_epochs = 10, lr = 0.1, verbose = False)
    print(f'Remove mean: {remove_mean} Remove var: {remove_var}\t\tfinal test acc: {test_accs[-1]}')
Remove mean: True Remove var: True      final test acc: 0.9674
Remove mean: False Remove var: True     final test acc: 0.9676
Remove mean: True Remove var: False     final test acc: 0.9796
Remove mean: False Remove var: False        final test acc: 0.98

It appears that freezing the scale (the “remove var” variant) hurts more than freezing the shift in our case. Is this generally true? I haven’t found anything online yet.
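A more literal reading of the exercise would skip the mean subtraction or the variance division inside the layer itself. The following is a hedged sketch of such a layer for `(batch, features)` inputs only; `LiteBatchNorm1d` and its `center` / `scale` flags are hypothetical names, and this module was not used for the results above.

```python
import torch
import torch.nn as nn

class LiteBatchNorm1d(nn.Module):
    """Hypothetical 'lite' batch norm for (batch, features) inputs.

    center=True subtracts the mean; scale=True divides by the standard deviation.
    """
    def __init__(self, num_features, center=True, scale=True, eps=1e-5, momentum=0.1):
        super().__init__()
        self.center, self.scale = center, scale
        self.eps, self.momentum = eps, momentum
        self.weight = nn.Parameter(torch.ones(num_features))    # gamma
        self.bias = nn.Parameter(torch.zeros(num_features))     # beta
        self.register_buffer('running_mean', torch.zeros(num_features))
        self.register_buffer('running_var', torch.ones(num_features))

    def forward(self, x):
        if self.training:
            mean, var = x.mean(dim=0), x.var(dim=0, unbiased=False)
            with torch.no_grad():         # moving averages, used at inference time
                self.running_mean.lerp_(mean.detach(), self.momentum)
                self.running_var.lerp_(var.detach(), self.momentum)
        else:
            mean, var = self.running_mean, self.running_var
        if self.center:
            x = x - mean
        if self.scale:
            x = x / torch.sqrt(var + self.eps)
        return self.weight * x + self.bias

torch.manual_seed(0)
x = torch.randn(64, 5) * 3 + 2
mean_only = LiteBatchNorm1d(5, center=True, scale=False)(x)
var_only = LiteBatchNorm1d(5, center=False, scale=True)(x)
print(mean_only.mean(dim=0), mean_only.std(dim=0))   # centered, spread unchanged
print(var_only.mean(dim=0), var_only.std(dim=0))     # unit spread, mean not removed
```

Swapping it into BNLeNet would additionally require a 2d variant that reduces over the batch and spatial dimensions for the convolutional layers.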

Q5

Fix the parameters beta and gamma. Observe and analyze the results.

We can accomplish this by setting affine = False in the BatchNorm layers, which fixes \(\gamma = 1\) and \(\beta = 0\).
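A quick check of what the flag does, sketched on a standalone `nn.BatchNorm1d` (the feature count of 8 is arbitrary): with `affine=False` the layer has no learnable parameters at all, so there is nothing for the optimizer to update.

```python
import torch.nn as nn

bn_affine = nn.BatchNorm1d(8)                  # default: affine=True
bn_fixed = nn.BatchNorm1d(8, affine=False)

num_learnable_affine = sum(p.numel() for p in bn_affine.parameters())
num_learnable_fixed = sum(p.numel() for p in bn_fixed.parameters())
print(num_learnable_affine, num_learnable_fixed)   # 16 0
print(bn_fixed.weight, bn_fixed.bias)              # None None
```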

Model definition
class BNLeNet(nn.Module):
    def __init__(self, num_classes = 10, affine_bn = True):
        super().__init__()
        self.net = nn.Sequential(
            nn.LazyConv2d(6, kernel_size=5), nn.LazyBatchNorm2d(affine = affine_bn), nn.Sigmoid(),
            nn.AvgPool2d(kernel_size=2, stride=2),
            nn.LazyConv2d(16, kernel_size=5), nn.LazyBatchNorm2d(affine = affine_bn), nn.Sigmoid(),
            nn.AvgPool2d(kernel_size=2, stride=2),
            nn.Flatten(),
            nn.LazyLinear(120), nn.LazyBatchNorm1d(affine = affine_bn), nn.Sigmoid(),
            nn.LazyLinear(84), nn.LazyBatchNorm1d(affine = affine_bn), nn.Sigmoid(),
            nn.LazyLinear(num_classes))
        
    def forward(self, x):
        return self.net(x)
Train and evaluate
results = {'affine': [], 'non affine': []}
os.makedirs('nets', exist_ok = True)

for affine in [True, False]:
    aff = 'affine' if affine else 'non affine'
    for seed in range(5):

        # All nets "see" the same data
        torch.manual_seed(seed); np.random.seed(seed); random.seed(seed)
        g = torch.Generator()
        g.manual_seed(seed)

        net = BNLeNet(affine_bn = affine)

        train_loader = DataLoader(training_data, batch_size = 128, shuffle = True, worker_init_fn = seed_worker, generator = g)
        test_accs = train(net, train_loader, test_loader, num_epochs = 10, lr = 0.1, verbose = False)

        results[aff].append(test_accs)
        torch.save(net, f'nets/{aff}_{seed}.pt')

        print(f'Affine: {affine}\t\tSeed: {seed}\t\tfinal test acc: {test_accs[-1]}')
Plot validation accuracies
tmp = []
for aff, res in results.items():
    for i, test_accs in enumerate(res):
        for epoch, acc in enumerate(test_accs):
            tmp.append({'Affine': aff, 'Val accuracy': acc, 'Seed': i, 'Epoch': epoch + 1})
tmp = pd.DataFrame(tmp)
sns.lineplot(data = tmp, x = 'Epoch', y = 'Val accuracy', hue = 'Affine')

It seems performance is worse during early epochs but converges as training progresses. Let’s see whether the weights and biases learned by the batch norm layers with affine=True departed from their defaults (\(1\) and \(0\) respectively).

Plot weights and biases
layers = [1, 5, 10, 13]
weights = {l:[] for l in layers}
biases = {l:[] for l in layers}

for fname in os.listdir('nets'):
    if 'non' in fname: continue
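    # The whole module was pickled (not just a state dict), so weights_only must be off
    net = torch.load(os.path.join('nets', fname), weights_only = False)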
    for layer in layers:
        weights[layer].extend(net.net[layer].weight.detach().tolist())
        biases[layer].extend(net.net[layer].bias.detach().tolist())

f, axs = plt.subplots(2, 4, figsize=(12, 6))
for i, (layer, w) in enumerate(weights.items()):
    sns.histplot(w, ax = axs[0, i])
    axs[0, i].set_title(f'Layer {layer} BN weights')
    axs[0, i].axvline(1, c = 'r')

    sns.histplot(biases[layer], ax = axs[1, i])
    axs[1, i].set_title(f'Layer {layer} BN biases')
    axs[1, i].axvline(0, c = 'r')

    if i > 0: axs[0, i].set_ylabel(''); axs[1, i].set_ylabel('')

f.tight_layout()

It seems that they did, especially the weights.

Note: In Q4 and Q5 we removed various components of the batch norm layers and compared performance. However, we used the same fixed learning rate \(0.1\) for all configurations. A more thorough analysis would find the optimal learning rate for each configuration, as Appendix F of the paper does.

Q6

Can you replace dropout by batch normalization? How does the behavior change?

Q7

Research ideas: think of other normalization transforms that you can apply:

  1. Can you apply the probability integral transform?

  2. Can you use a full-rank covariance estimate? Why should you probably not do that?

  3. Can you use other compact matrix variants (block-diagonal, low-displacement rank, Monarch, etc.)?

  4. Does a sparsification compression act as a regularizer?

  5. Are there other projections (e.g., convex cone, symmetry group-specific transforms) that you can use?