TENT: Fully Test-Time Adaptation By Entropy Minimization

An attempted (partial) paper reproduction

Categories: deep learning, paper

Published: December 29, 2024

Once a model is deployed, the feature (covariate) distribution of the data might shift from the one seen during training. These shifts push the model out of distribution and worsen its predictions. This paper proposes a simple method to help models adapt to such shifts: minimize the entropy of your predictions.

That is, before making test-time predictions for a batch, you nudge the model with a gradient step (SGD) toward peakier (less entropic) class distributions.
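Concretely, for each test batch the adaptation loss is the mean Shannon entropy of the predicted class distributions,

$$H(\hat{y}) = -\sum_{c} \hat{p}(y_c \mid x)\,\log \hat{p}(y_c \mid x),$$

averaged over the examples in the batch; this is exactly the quantity computed from the softmax outputs in the code further down.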

Why minimize entropy?

First, because it is convenient. In contrast to other adaptation methods, you don’t need to modify the training procedure, and you don’t need test-time labels. Since labels are rarely available at test time, this is what makes TENT “fully test-time”.

Second, the authors argue that entropy is related to both error and shifts:

“Entropy is related to error, as more confident predictions are all-in-all more correct (Figure 1). Entropy is related to shifts due to corruption, as more corruption results in more entropy, with a strong rank correlation to the loss for image classification as the level of corruption increases (Figure 2).”

To reproduce Figures 1 & 2 we train a ResNet on CIFAR-10 and evaluate its predictions on corrupted versions of the test set to simulate test-time shifts.

(Note: while the authors also show results for CIFAR-100 and ImageNet, we’ll only deal with this small dataset and model for convenience.)

Datasets
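(The snippets below assume the following imports; load_cifar10c, load_model, and ThreatModel come from the robustbench package, and device is assumed to be defined as shown.)

import itertools, functools

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from torchvision import datasets, transforms

import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

# RobustBench provides the CIFAR-10-C loader and a zoo of pretrained CIFAR-10 models
from robustbench.data import load_cifar10c
from robustbench.utils import load_model
from robustbench.model_zoo.enums import ThreatModel

device = 'cuda' if torch.cuda.is_available() else 'cpu'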
corruption_types = [
    'gaussian_noise', 'shot_noise', 'impulse_noise', 'defocus_blur', 'glass_blur', 'motion_blur', 'zoom_blur',
    'snow', 'frost', 'fog', 'brightness', 'contrast', 'elastic_transform', 'pixelate', 'jpeg_compression'
]

# train_set = datasets.CIFAR10('../data', download = True, train = True,  transform = transforms.ToTensor())
test_set =  datasets.CIFAR10('../data', download = True, train = False, transform = transforms.ToTensor())
n_classes = len(test_set.classes)

def get_test_set(corr_type = 'brightness', data_path = 'data', severity = 5):
    # Return CIFAR-10-C for the given corruption & severity, or the clean test set if corr_type is None
    assert 1 <= severity <= 5
    if corr_type in corruption_types:
        X_test, y_test = load_cifar10c(10_000, severity, data_path, False, [corr_type])
        return TensorDataset(X_test, y_test)
    return test_set

def get_cifar10_model(model_path = 'models/cifar10_pretrained'):
    # Cache the pretrained RobustBench model locally; download it on first use
    try:
        return torch.load(model_path)
    except FileNotFoundError:
        m = load_model('Standard', 'models', 'cifar10', ThreatModel.corruptions)
        torch.save(m, model_path)
        return m


get_model = lambda: load_model('Standard', 'models', 'cifar10', ThreatModel.corruptions)
model = get_model()

4 of 15 corruption types included in CIFAR-10-C, shown at the highest severity (5/5) level
Reproduce Figs 1 & 2
model = get_model()

c, e, l = [], [], []
corruptions, severities = [], []

for corr, severity in itertools.product([None] + corruption_types, range(1, 6)):  # None = the clean (uncorrupted) test set

    if corr is None and severity > 1: continue
    corrupted_test_set = get_test_set(corr_type = corr, severity = severity)
   
    test = DataLoader(corrupted_test_set, batch_size = 128, shuffle = False)

    model.to(device); model.eval()
    with torch.no_grad():
        for images, labels in test:
            images, labels = images.to(device), labels.to(device)
            logits = model(images)
            _, pred = torch.max(logits, 1)
            correct = (pred == labels).float()
            entropy = -(logits.softmax(dim = 1) * logits.log_softmax(dim = 1)).sum(dim = 1)
            loss = nn.CrossEntropyLoss(reduction = 'none')(logits, labels)
            c.append(correct)
            e.append(entropy)
            l.append(loss)
            corruptions.extend([corr] * len(correct))
            severities.extend([severity] * len(correct))

correct = torch.cat(c).cpu().numpy()
entropy = torch.cat(e).cpu().numpy()
loss    = torch.cat(l).cpu().numpy()
Figure 1: Preds with less entropy have lower error rates.
Figure 2: More corruption (shown as alpha) leads to higher loss and entropy.

The intuition, as far as I can tell, is that entropy encodes the model’s confidence. If the model’s prediction is confident, it is on the whole more likely to be correct (it might have seen similar examples during training, the example might be “easy”, etc.). Corruptions take the model OOD and decrease its confidence. Since cross-entropy is lowest when all probability mass is assigned to the correct label, increasing entropy (on the whole) dilutes that mass and increases the loss.
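To make that link concrete, here is a toy (made-up) example, not from the paper: a peaked prediction has low entropy and, when the peak sits on the right class, low cross-entropy, while a flatter prediction scores high on both.

import torch
import torch.nn.functional as F

# Two toy 4-class predictions for an example whose true label is class 0
confident = torch.tensor([[4.0, 0.0, 0.0, 0.0]])  # peaked logits
diffuse   = torch.tensor([[0.5, 0.3, 0.2, 0.1]])  # nearly flat logits
label = torch.tensor([0])

for name, logits in [('confident', confident), ('diffuse', diffuse)]:
    p = logits.softmax(dim = 1)
    entropy = -(p * p.log()).sum(dim = 1)
    ce = F.cross_entropy(logits, label)
    print(f'{name}: entropy = {entropy.item():.2f}, cross-entropy = {ce.item():.2f}')

# Roughly: confident -> entropy ≈ 0.26, cross-entropy ≈ 0.05
#          diffuse   -> entropy ≈ 1.37, cross-entropy ≈ 1.17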

Two important notes on how entropy is minimized:

First, the authors note that once we switch the model to entropy minimization we risk making it drift away from what it learned during training. While you could choose a sufficiently small learning rate or add KL regularization to counteract this, the authors opt for freezing most of the model and only updating the learnable (affine) parameters of the batch norm layers.

Second, we must adapt on batches rather than single examples. Minimizing the entropy of a single example would simply push all of the probability mass onto the most likely class.
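To make the first note concrete, here is a minimal sketch of handing the optimizer only the BatchNorm affine parameters (an alternative to the requires_grad-based helper used in the example below; bn_params is a name introduced here for illustration):

# Collect only the BatchNorm scale/shift (affine) parameters;
# everything else is simply never given to the optimizer.
bn_params = []
for m in model.modules():
    if isinstance(m, nn.BatchNorm2d):
        bn_params += [p for p in (m.weight, m.bias) if p is not None]
optimizer = optim.AdamW(bn_params, lr = 1e-5)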

TENT example
# Only update BN layers
def prepare_for_test_time(module, reset_stats = True):
    # Make the BatchNorm affine parameters trainable and freeze everything else
    if isinstance(module, nn.BatchNorm2d):
        if reset_stats: module.reset_running_stats()
        for p in module.parameters(recurse = False): p.requires_grad = True
    else:
        for p in module.parameters(recurse = False): p.requires_grad = False

    for m in module.children(): prepare_for_test_time(m, reset_stats)


# Init the model & optimizer
model = get_cifar10_model(); model.to(device)
corr_test_set = get_test_set(corr_type = 'gaussian_noise', severity = 5)
model.apply(functools.partial(prepare_for_test_time, reset_stats = False))
optimizer = optim.AdamW(model.parameters(), lr = 0.00001)

# Get a batch of corrupted images
images, labels = next(iter(DataLoader(corr_test_set, batch_size = 128)))
images, labels = images.to(device), labels.to(device)

# Minimize entropy
model.train()
optimizer.zero_grad()
preds = model(images)
entropy = -(preds.softmax(dim = 1) * preds.log_softmax(dim = 1)).sum(dim = 1)
loss = entropy.mean()
loss.backward()
optimizer.step()

with torch.no_grad():
    new_preds = model(images)

# Plot
f, axs = plt.subplots(1, 3, figsize = (8, 3))
ix = 89
axs[0].imshow(test_set[ix][0].permute(1, 2, 0).cpu().numpy())
axs[0].set_title(test_set.classes[test_set[ix][1]])
axs[0].set_xticks([]); axs[0].set_yticks([])

axs[1].imshow(images[ix].permute(1, 2, 0).cpu().numpy())
axs[1].set_title('corrupted')
axs[1].set_xticks([]); axs[1].set_yticks([])

order = torch.argsort(preds.softmax(dim = 1)[ix]).detach().cpu().numpy()[::-1]
rows = [{'class': i, 'prob':p, 'type': 'unadapted'} for i, p in enumerate(preds.softmax(dim = 1)[ix].detach().cpu().numpy()[order])]
rows.extend([{'class': i, 'prob':p, 'type': 'TENT'} for i, p in enumerate(new_preds.softmax(dim = 1)[ix].detach().cpu().numpy()[order])])
sns.barplot(x = 'class', y = 'prob', hue = 'type', data = pd.DataFrame(rows), ax = axs[2])
axs[2].set_title('class distribution')
axs[2].legend()
axs[2].set_xticks([]); axs[2].set_yticks([])

f.tight_layout()
Figure 3: TENT produces class distributions with less entropy, concentrating mass in fewer classes.

Now for evaluation. While the authors consider other baselines, for simplicity we only compare TENT against the unadapted source model and a test-time normalization baseline (“Norm”), which simply re-estimates the BN statistics on the test batches.

Eval unadapted model
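(The eval_model helper isn’t shown in the post; here is a minimal sketch of what it is assumed to do, returning a (mean loss, accuracy) pair over a DataLoader.)

# Assumed helper (not shown in the original post)
@torch.no_grad()
def eval_model(model, loader):
    model.eval()
    total_loss, n_correct, n = 0.0, 0, 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        logits = model(images)
        total_loss += nn.functional.cross_entropy(logits, labels, reduction = 'sum').item()
        n_correct += (logits.argmax(dim = 1) == labels).sum().item()
        n += labels.shape[0]
    return total_loss / n, n_correct / n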
def eval_source(init_model_fn, severity = 5, batch_size = 128):
    model = init_model_fn(); model.to(device)
    results = {}
    for corr_type in corruption_types:
        corr_test_set = get_test_set(corr_type = corr_type, severity = severity)
        _, source_acc = eval_model(model, DataLoader(corr_test_set, batch_size = batch_size))
        results[corr_type] = 1 - source_acc  # store error rate, not accuracy
    return results

source_results = eval_source(get_model)
torch.save(source_results, 'logs/source_results')
Eval Norm
def reset_bn_stats(module):
    if isinstance(module, nn.BatchNorm2d):
        module.reset_running_stats()
    for m in module.children(): reset_bn_stats(m)


@torch.no_grad()
def eval_norm(init_model_fn, severity = 5, batch_size = 128, reset_stats = False, corr_types = None):

    results_acc = {}
    if corr_types is None: corr_types = [None] + corruption_types
    for corr_type in corr_types:
        print(corr_type)
        
        model = init_model_fn(); model.to(device)  # Re-init the model
        if reset_stats: model.apply(reset_bn_stats)
        corr_test_set = get_test_set(corr_type = corr_type, severity = severity)

        for i, (images, labels) in enumerate(DataLoader(corr_test_set, batch_size = batch_size)):
            images, labels = images.to(device), labels.to(device)

            # Update the BN stats: a forward pass in train mode re-estimates them on the test batch
            model.train()
            preds = model(images)

            err = (torch.max(preds, 1)[1] != labels).float().sum().item() / labels.shape[0]
            results_acc[corr_type] = results_acc.get(corr_type, []) + [err]  # per-batch error rates
            if i % 15 == 0: print(err)
        
    return results_acc

norm_results = eval_norm(get_cifar10_model, reset_stats = False)
torch.save(norm_results, 'logs/norm_results_all')
Eval TENT
# prepare_for_test_time (freeze everything except the BN layers) is defined above in the TENT example

def eval_tent(init_model_fn, severity = 5, lr = 0.001, batch_size = 128, reset_stats = True, corr_types = None):

    results_acc = {}
    results_e = {}
    if corr_types is None: corr_types = [None] + corruption_types

    for corr_type in corr_types:
        print(corr_type)

        # Re-init the model & optimizer (freeze all but BN, as above)
        model = init_model_fn(); model.to(device)
        model.apply(functools.partial(prepare_for_test_time, reset_stats = reset_stats))
        optimizer = optim.AdamW(model.parameters(), lr = lr)

        corr_test_set = get_test_set(corr_type = corr_type, severity = severity)

        for i, (images, labels) in enumerate(DataLoader(corr_test_set, batch_size = batch_size)):
            images, labels = images.to(device), labels.to(device)

            # Minimize entropy
            model.train()
            optimizer.zero_grad()
            preds = model(images)
            entropy = -(preds.softmax(dim = 1) * preds.log_softmax(dim = 1)).sum(dim = 1)
            loss = entropy.mean()
            loss.backward()
            optimizer.step()

            # Error is computed from the same forward pass used for adaptation (predict, then update)
            err = (torch.max(preds, 1)[1] != labels).float().sum().item() / labels.shape[0]
            results_acc[corr_type] = results_acc.get(corr_type, []) + [err]
            results_e[corr_type] = results_e.get(corr_type, []) + [loss.item()]
            if i % 15 == 0: print(err)

    return results_acc, results_e


# tent_acc, tent_entropy = eval_tent(get_cifar10_model, reset_stats = False, lr = 0.00001)
# torch.save(tent_acc, 'logs/tent_acc'); torch.save(tent_entropy, 'logs/tent_entropy')

# tent_acc_r, tent_entropy_r = eval_tent(get_model, reset_stats = True, lr = 0.00001)
# torch.save(tent_acc_r, 'logs/tent_acc_r'); torch.save(tent_entropy_r, 'logs/tent_entropy_r')
Figure 4: TENT & Norm consistently outperform the unadapted model, with TENT (lr = 1e-5, batch_size = 128) taking a slight lead.
Hyperparam grid
grid_results = {}
for lr, b in itertools.product([1e-4, 1e-5, 1e-6], range(4, 9)):
    batch_size = int(2 ** b)
    print(lr, batch_size)
    tent_acc, tent_entropy = eval_tent(get_cifar10_model, reset_stats = False, lr = lr, batch_size = batch_size, corr_types = ['gaussian_noise'])
    grid_results[(lr, batch_size)] = (tent_acc, tent_entropy)

torch.save(grid_results, 'logs/grid_results_all')

grid_results_norm = {}
for b in range(4, 9):
    batch_size = int(2 ** b)
    print(batch_size)
    acc = eval_norm(get_cifar10_model, reset_stats = False, batch_size = batch_size, corr_types = ['gaussian_noise'])
    grid_results_norm[batch_size] = acc

torch.save(grid_results_norm, 'logs/grid_results_norm')

The paper shows TENT having more of a lead on this dataset, but this is the best I could do.

How sensitive is it to hyperparameters? TENT has two: the test-time learning rate and batch size. We vary these and show results for the gaussian_noise corruption.

Figure 5: TENT consistently outperforms other methods across corruption types.
Figure 6: Mean entropy over test set

You can see that TENT seems quite sensitive to hyperparameters, which is a common challenge for test-time adaptation methods. There definitely seems to be an entropy sweet spot, presumably specific to the dataset and shift, that the learning rate and batch size control.

Wrapping up, TENT is a compelling TTA method: you can use a pre-trained model as-is and don’t need test-time labels. However, the technique is limited to classification and to online, batched prediction (we need batches), it miscalibrates models (makes them overconfident), and it is sensitive to hyperparameters.

All in all, it was an interesting paper that introduced me to the test-time adaptation literature.