Estimating and maximizing information with respect to multiple targets in Deep Representations

In this notebook, we showcase how torch_mist can be used to train neural representations and to evaluate them quantitatively and qualitatively. This example focuses on measuring and optimizing how much information a deep resnet18 representation contains about the position of the eyes, nose and mouth in CelebA pictures.

By estimating mutual information with a simple generative estimator, we demonstrate that a resnet18 architecture pretrained on ImageNet contains more information about the landmark locations than a randomly initialized one, and that it is possible to increase the amount of information about the landmark locations by fine-tuning the encoder. The scripts/celeba_landmarks.py script contains details on the helper classes and the training procedure.

Loading the CelebA dataset

The CelebALandmarks dataset is an extension of the torchvision.datasets.CelebA dataset, modified to return a dictionary containing images and landmarks.
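
As a rough illustration of what such a wrapper might do (the actual CelebALandmarks class lives in the helper script and may differ), the sketch below splits a flat vector of ten landmark coordinates into five named (x, y) pairs and returns them together with the image in one dictionary. The names `LANDMARK_NAMES` and `to_landmark_dict`, the assumed coordinate order, and the coordinates in the usage example are all hypothetical; only the dictionary keys are taken from the batch printed later in this notebook.

```python
# Hypothetical sketch: not the CelebALandmarks implementation.
# Assumed layout of the flat vector: (x, y) for each landmark in turn.
LANDMARK_NAMES = ["left_eye", "right_eye", "nose", "left_mouth", "right_mouth"]


def to_landmark_dict(image, flat_landmarks):
    """Pack an image and a flat list of 10 coordinates into one dictionary."""
    if len(flat_landmarks) != 2 * len(LANDMARK_NAMES):
        raise ValueError("expected one (x, y) pair per landmark")
    sample = {"image": image}
    for i, name in enumerate(LANDMARK_NAMES):
        # Each landmark is stored as an [x, y] pair
        sample[name] = list(flat_landmarks[2 * i : 2 * i + 2])
    return sample


# Toy usage with a placeholder image and made-up coordinates
toy_sample = to_landmark_dict(
    "image-placeholder",
    [69, 109, 106, 113, 77, 142, 73, 152, 108, 154],
)
print(sorted(toy_sample))
```

Returning a dictionary rather than a tuple lets later components refer to each variable by name, which is what the estimators in the following sections rely on.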

We first instantiate and visualize the training data:

[1]:
import matplotlib.pyplot as plt
import seaborn as sns
sns.set_style('whitegrid')
from docs.notebooks.examples.scripts.celeba_landmarks import CelebALandmarks, LANDMARKS, REPRESENTATIONS

# CelebA download location.
DATA_DIR = '/data'

# Load the CelebA Dataset
dataset = CelebALandmarks(DATA_DIR)
[2]:
# Consider a datapoint
idx = 0
data = dataset[idx]

# We plot image and the corresponding landmarks
for name, value in data.items():
    if name == 'image':
        plt.imshow(value.permute(1,2,0))
    else:
        plt.plot(value[0],value[1], 'o', label=name)

plt.legend()
plt.axis('off');
[Figure: a CelebA sample image with its landmarks overlaid (../../_images/notebooks_examples_keypoint_info_3_0.png)]
[3]:
from torch.utils.data import DataLoader, Subset

num_workers = 8

# Lastly, we define two dataloaders for training and testing purposes
trainloader = DataLoader(
    dataset,
    batch_size=32,
    num_workers=num_workers
)

# To speed up the evaluation process we evaluate mutual information only for the first 1000 pictures
testloader = DataLoader(
    Subset(dataset, range(1000)),
    batch_size=32,
    num_workers=num_workers
)
[4]:
# Each batch is composed of a dictionary that contains images and all the landmarks
batch = next(iter(trainloader))
print({name: t.shape for name, t in batch.items()})
{'image': torch.Size([32, 3, 218, 178]), 'left_eye': torch.Size([32, 2]), 'right_eye': torch.Size([32, 2]), 'nose': torch.Size([32, 2]), 'left_mouth': torch.Size([32, 2]), 'right_mouth': torch.Size([32, 2])}

Multi-Mutual Information Estimation

We use the torch_mist package to define 5x3 = 15 mutual information estimators (one for each pair of landmark and encoder), which are jointly trained on the data using the MultiMIEstimator class. Each estimator is an instance of the DoE (difference of entropies) estimator. The aim of this tutorial is to show how to combine elements from torch_mist to easily create complex models.
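
To give some intuition for what a difference-of-entropies estimate computes, here is a minimal toy sketch that does not use torch_mist. It assumes a pair of correlated one-dimensional Gaussian variables, fits a Gaussian model for the marginal and a linear-Gaussian model for the conditional, and compares the resulting estimate with the analytic mutual information. The function names `gaussian_log_prob` and `toy_doe_estimate` are hypothetical, and the models are fitted and evaluated on the same samples for brevity.

```python
import math
import random


def gaussian_log_prob(value, mean, var):
    """Log-density of a 1D Gaussian with the given mean and variance."""
    return -0.5 * (math.log(2 * math.pi * var) + (value - mean) ** 2 / var)


def toy_doe_estimate(xs, ys):
    """Difference-of-entropies estimate of I(x; y) in nats with Gaussian models."""
    n = len(xs)
    mean_x, mean_y = sum(xs) / n, sum(ys) / n
    var_x = sum((x - mean_x) ** 2 for x in xs) / n
    var_y = sum((y - mean_y) ** 2 for y in ys) / n
    cov_xy = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys)) / n

    # q(y|x): linear-Gaussian model fitted by least squares
    slope = cov_xy / var_x
    intercept = mean_y - slope * mean_x
    residual_var = sum(
        (y - (slope * x + intercept)) ** 2 for x, y in zip(xs, ys)
    ) / n

    # Cross-entropy estimates of H(y) under q(y) and of H(y|x) under q(y|x)
    h_y = -sum(gaussian_log_prob(y, mean_y, var_y) for y in ys) / n
    h_y_given_x = -sum(
        gaussian_log_prob(y, slope * x + intercept, residual_var)
        for x, y in zip(xs, ys)
    ) / n
    return h_y - h_y_given_x


# Correlated Gaussian pair with a known mutual information
rng = random.Random(0)
rho = 0.8
toy_xs = [rng.gauss(0.0, 1.0) for _ in range(20000)]
toy_ys = [rho * x + math.sqrt(1 - rho**2) * rng.gauss(0.0, 1.0) for x in toy_xs]

toy_estimate = toy_doe_estimate(toy_xs, toy_ys)
true_mi = -0.5 * math.log(1 - rho**2)
print(f"estimate: {toy_estimate:.3f} nats, analytic value: {true_mi:.3f} nats")
```

In the notebook, the two Gaussian models are replaced by learned distributions over landmark positions, and x is a ResNet representation instead of a scalar.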

Defining the Encoders

[5]:
import torch
from torch_mist.utils import freeze

from docs.notebooks.examples.scripts.celeba_landmarks import ResNet18Encoder


# We first define three encoders, which will produce three representations of each picture
encoders = {
    #   random_init: A frozen randomly initialized resnet18 encoder
    'random_init': freeze(ResNet18Encoder(weights=None)),
    #   pretrained: A frozen pre-trained resnet18 encoder
    'pretrained': freeze(ResNet18Encoder(weights="DEFAULT")),
    #   finetuned: A pre-trained resnet18 that will be finetuned with the gradient from the mi_estimator
    'finetuned': ResNet18Encoder(weights="DEFAULT"),
}

One estimator for each pair

[6]:
from docs.notebooks.examples.scripts.celeba_landmarks import instantiate_estimators
from torch_mist.estimators import MultiMIEstimator

# We instantiate one mutual information estimator for each pair of representation (3) and landmark (5)
estimators = instantiate_estimators()

# They are stored in a dictionary whose keys are (representation, landmark) tuples and whose values are the corresponding estimators
# Each estimator is defined as a difference of entropies (DoE) estimator
for (representation, landmark), estimator in estimators.items():
    print(f"({representation}, {landmark}): {estimator.__class__.__name__}")


(random_init, left_eye): DoE
(pretrained, left_eye): DoE
(finetuned, left_eye): DoE
(random_init, right_eye): DoE
(pretrained, right_eye): DoE
(finetuned, right_eye): DoE
(random_init, nose): DoE
(pretrained, nose): DoE
(finetuned, nose): DoE
(random_init, left_mouth): DoE
(pretrained, left_mouth): DoE
(finetuned, left_mouth): DoE
(random_init, right_mouth): DoE
(pretrained, right_mouth): DoE
(finetuned, right_mouth): DoE

Aggregating estimators

We can aggregate all the estimators using the MultiMIEstimator class, which handles running multiple estimators in parallel, provided that it receives a dictionary containing all representations and landmarks for each batch.

Note that, depending on the specified estimator, it may be possible to obtain InfoMax gradients with respect to the various components (landmarks or representations). This is relevant when we are interested in training InfoMax encoders or fine-tuning encoders, as in this example.

To determine whether gradients can be backpropagated to a component, we can inspect the infomax_gradient attribute. Here we see that we get gradients w.r.t. the representations but not the landmarks. This is because we used the DoE estimator, which provides gradients w.r.t. its first argument (the representation) but not its second (the landmark).

[7]:
# Aggregate the estimators
multi_estimator = MultiMIEstimator(estimators)

# Check which components receive infoMax gradient
print(multi_estimator.infomax_gradient)
{'random_init': True, 'left_eye': False, 'pretrained': True, 'finetuned': True, 'right_eye': False, 'nose': False, 'left_mouth': False, 'right_mouth': False}
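
The aggregation described in this section can be pictured with a small, library-free sketch: a dictionary keyed by (representation, landmark) pairs is looped over, and each estimator receives the two matching entries of the batch dictionary. This is only an illustration of the idea and not the MultiMIEstimator implementation; `evaluate_all` and `dummy_estimator` are hypothetical names, and the result keys merely imitate the `I(x;y)` format printed later in this notebook.

```python
# Hypothetical sketch of the aggregation idea; not the MultiMIEstimator code.
def evaluate_all(estimators, batch):
    """Run every (x_name, y_name) estimator on the matching batch entries."""
    results = {}
    for (x_name, y_name), estimator in estimators.items():
        results[f"I({x_name};{y_name})"] = estimator(batch[x_name], batch[y_name])
    return results


# Stand-in "estimator": any callable taking two batch entries and returning a number
def dummy_estimator(x, y):
    return float(len(x) == len(y))


toy_estimators = {
    (representation, landmark): dummy_estimator
    for landmark in ["left_eye", "nose"]
    for representation in ["random_init", "pretrained"]
}
# Made-up values, one row per element of the batch
toy_batch = {
    "random_init": [[0.1], [0.2]],
    "pretrained": [[0.3], [0.4]],
    "left_eye": [[69, 109], [70, 110]],
    "nose": [[77, 142], [78, 141]],
}
toy_results = evaluate_all(toy_estimators, toy_batch)
print(toy_results)
```

Keying the estimators by variable names is what allows a single batch dictionary to feed all pairs at once.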

Applying transformations

Since we are interested in estimating the mutual information between the resnet representations and not the original images, we use the TransformedMIEstimator class to apply the 3 encoders to each image to produce 3 representations (random_init, pretrained, finetuned).

This can be done by specifying a dictionary with entries of the form (name_from, name_to): transform. Each entry is applied to the batch dictionary to produce a new dictionary in which the keys (random_init, pretrained, finetuned) are added as batch[name_to] = transform(batch[name_from]).

[8]:
from torch_mist.estimators import TransformedMIEstimator

mi_estimator = TransformedMIEstimator(
    transforms_rename={
        ("image", representation_name): encoder
        for representation_name, encoder in encoders.items()
    },
    # Use the previously defined multi_estimator after applying the transforms.
    base_estimator=multi_estimator,
)
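
The effect of the `transforms_rename` argument described above can be sketched in a few lines of plain Python. This is an illustration of the rule batch[name_to] = transform(batch[name_from]) and not the TransformedMIEstimator implementation; `apply_transforms` and the toy encoders are hypothetical, and lists of numbers stand in for image tensors.

```python
# Hypothetical sketch of the renaming logic; not the TransformedMIEstimator code.
def apply_transforms(batch, transforms_rename):
    """Return a copy of batch with batch[name_to] = transform(batch[name_from])."""
    transformed = dict(batch)
    for (name_from, name_to), transform in transforms_rename.items():
        transformed[name_to] = transform(batch[name_from])
    return transformed


# Toy "encoders" acting on a list of numbers instead of an image tensor
toy_encoders = {
    "sum_encoder": lambda image: sum(image),
    "max_encoder": lambda image: max(image),
}
toy_inputs = {"image": [1.0, 2.0, 3.0], "nose": [77, 142]}

toy_transformed = apply_transforms(
    toy_inputs,
    {("image", name): encoder for name, encoder in toy_encoders.items()},
)
print(toy_transformed)
```

The original entries are kept, so the landmarks remain available to the estimators that consume the new representation keys.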

Training the estimators

The estimators can be easily trained using the train_mi_estimator utility. The training process requires a few hours on a single GPU.

In this example, we load an estimator that was trained and logged using Weights & Biases. Use pip install wandb to install the required dependency, or set train = True to train the model in this notebook.

[9]:
from torch_mist.utils.train import train_mi_estimator

train = False

if train:
    # Training procedure
    device = 'cuda'
    log = train_mi_estimator(
        estimator=mi_estimator,
        data=trainloader,
        max_epochs=20,
        verbose=True,
        device=device,
    )
else:
    # Loading from a checkpoint
    import wandb
    import os

    # Download the artifact corresponding to the model from wandb
    api = wandb.Api()
    artifact = api.artifact("mfederici/torch_mist/celeba_landmarks:latest")
    datadir = artifact.download()

    # Load the trained mutual information estimator
    trained_mi_estimator = torch.load(
        os.path.join(datadir, 'trained.pyt'),
        map_location=torch.device('cpu')
    )
    # Copy the state dictionary (trained_mi_estimator and mi_estimator have the same structure)
    mi_estimator.load_state_dict(trained_mi_estimator.state_dict())

mi_estimator.eval();
wandb: Downloading large artifact celeba_landmarks:latest, 157.12MB. 1 files...
wandb:   1 of 1 files downloaded.
Done. 0:0:0.6

Results and visualizations

Using the trained estimators, we can estimate the mutual information between all pairs of representations and landmarks.

[10]:
from torch_mist.utils import evaluate_mi

# Evaluate mutual information on a small subset (may require a few minutes on cpu)
mutual_information = evaluate_mi(
    mi_estimator,
    data=testloader
)

print(mutual_information)

{'I(random_init;left_eye)': 0.4707388657843694, 'I(pretrained;left_eye)': 0.9212275436148047, 'I(finetuned;left_eye)': 3.101716049015522, 'I(random_init;right_eye)': 0.47409993142355233, 'I(pretrained;right_eye)': 0.8931427914649248, 'I(finetuned;right_eye)': 2.9844541996717453, 'I(random_init;nose)': 0.6677500554360449, 'I(pretrained;nose)': 1.305055620148778, 'I(finetuned;nose)': 4.382897958159447, 'I(random_init;left_mouth)': 0.4304179784376174, 'I(pretrained;left_mouth)': 0.929740909487009, 'I(finetuned;left_mouth)': 3.1322446390986443, 'I(random_init;right_mouth)': 0.40875477041117847, 'I(pretrained;right_mouth)': 0.9572644671425223, 'I(finetuned;right_mouth)': 3.185377709567547}

We then plot the values as a bar chart.

[11]:
import pandas as pd

results = []

# For simplicity, we convert the dictionary of mutual information values into a pandas DataFrame
for name, value in mutual_information.items():
    name = name.replace('I(', '').replace(')','')
    result = {
        'Representation': name.split(';')[0],
        'Landmark': name.split(';')[1],
        'Mutual Information [nats]': value
    }
    results.append(result)

results = pd.DataFrame(results)


sns.barplot(
    results,
    x='Landmark', hue='Representation', y='Mutual Information [nats]',
    hue_order=REPRESENTATIONS,
    order=LANDMARKS
)
sns.despine()
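[Figure: bar plot of the estimated mutual information in nats for each landmark, grouped by representation (../../_images/notebooks_examples_keypoint_info_20_0.png)]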
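
As a complement to the plot, the estimates printed above can be averaged per representation and converted from nats to bits. The snippet below is a small stand-alone sketch: the numbers are copied from the printed output and rounded to four decimals, and the variable names `mi_nats` and `mean_mi` are introduced here for illustration only.

```python
import math

# Estimates in nats, copied from the printed output above (rounded to four decimals).
# Order of the values: left_eye, right_eye, nose, left_mouth, right_mouth
mi_nats = {
    "random_init": [0.4707, 0.4741, 0.6678, 0.4304, 0.4088],
    "pretrained": [0.9212, 0.8931, 1.3051, 0.9297, 0.9573],
    "finetuned": [3.1017, 2.9845, 4.3829, 3.1322, 3.1854],
}

mean_mi = {}
for representation, values in mi_nats.items():
    mean_nats = sum(values) / len(values)
    mean_bits = mean_nats / math.log(2)  # 1 nat = 1 / ln(2) bits
    mean_mi[representation] = mean_nats
    print(f"{representation}: {mean_nats:.2f} nats ({mean_bits:.2f} bits) on average")
```

The averages make the ordering random_init < pretrained < finetuned explicit, with the nose being the landmark with the highest estimate for every representation.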

Since the DoE estimator explicitly defines variational models \(q(\text{landmark})\) and \(q(\text{landmark}|\text{representation})\), we can visualize the marginal landmark distribution and the landmark distribution conditioned on each of the ResNet representations.
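
For reference, writing \(x\) for a representation and \(y\) for a landmark, a difference-of-entropies estimate built from these two models has the general form below. This is the textbook form of the estimator; the exact parametrization of \(q\) used by torch_mist is not shown in this notebook.

\[
\hat{I}(x;y) = \mathbb{E}_{p(x,y)}\left[\log q(y|x)\right] - \mathbb{E}_{p(y)}\left[\log q(y)\right]
\]

The second expectation is (minus) the cross-entropy of the marginal model and the first is (minus) the cross-entropy of the conditional model. A conditional model that concentrates its mass more sharply around the true landmark position assigns it a higher \(\log q(y|x)\), which increases the estimate.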

[12]:
import numpy as np

# Consider the same datapoint as before, adding a batch dimension to each entry
data = {key: value.unsqueeze(0) for key, value in dataset[idx].items()}

# And the nose location as the landmark
landmark_name = 'nose'
landmark_pos = data[landmark_name].data.numpy()[0]

# Encode all the representations
img = data['image']
encoded_data = mi_estimator.transform(**data)

# Prepare the axis
f, ax = plt.subplots(
    1,len(REPRESENTATIONS)+1,
    figsize=(len(REPRESENTATIONS)*4,4)
)

# Make a uniform grid of pixel coordinates (img has shape [batch, channel, height, width])
rows = np.arange(img.shape[2])+1  # vertical (y) coordinates
cols = np.arange(img.shape[3])+1  # horizontal (x) coordinates
Y, X = np.meshgrid(rows, cols)
XY = torch.cat([
    torch.FloatTensor(X).unsqueeze(-1),
    torch.FloatTensor(Y).unsqueeze(-1),
],-1)

# Consider the distribution q(nose) and evaluate the log-probability on the whole grid
q_Y = mi_estimator.base_estimator.estimators[f'random_init;{landmark_name}'].q_Y
q_y = q_Y.log_prob(XY).data

# Plot the image
ax[0].imshow(data['image'][0].permute(1,2,0))
# and the marginal probability corresponding to the nose location
ax[0].contour(X, Y, q_y, levels=np.linspace(-10,-1, 10), cmap='viridis')
# together with the actual nose location
ax[0].plot(landmark_pos[0], landmark_pos[1], 'xr', markersize=10)
ax[0].set_title(f'$\\log q$({landmark_name})')
ax[0].axis('off')

# Plot the same for each of the q(nose|representation) (one for each representation)
for i, representation_name in enumerate(REPRESENTATIONS):

    representation = encoded_data[representation_name].unsqueeze(0)
    # Consider q(nose|representation) and compute the log-probability on the whole grid
    q_Y_given_X = mi_estimator.base_estimator.estimators[f'{representation_name};{landmark_name}'].q_Y_given_X
    q_y_given_x = q_Y_given_X.condition(representation).log_prob(XY).data

    # Plot the image and the corresponding probability contour and ground truth location
    ax[i+1].imshow(data['image'][0].permute(1,2,0))
    ax[i+1].contour(X, Y, q_y_given_x, levels=np.linspace(-10,0, 10), cmap='viridis')
    ax[i+1].set_title(f'$\\log q$({landmark_name}|{representation_name})')
    ax[i+1].plot(landmark_pos[0], landmark_pos[1], 'xr', markersize=10)
    ax[i+1].axis('off')
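[Figure: contour plots of log q(nose) and of log q(nose|representation) for each representation, overlaid on the sample image (../../_images/notebooks_examples_keypoint_info_22_0.png)]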

Clearly, the randomly initialized network contains less information about the location of the landmarks than a network pretrained on image classification. As expected, fine-tuning the representation to maximize the information it carries about the landmark positions yields the most informative representation. This can also be seen by comparing the sharpness of the conditional (posterior) distributions in the contour plots.

Therefore, we may infer that pre-training on ImageNet makes the network capture more information about the location of facial features than random initialization does. However, the large gap to the fine-tuned encoder suggests that most of the information about their position is still lost in the encoding process.