Usage
The torch_mist package can be used to estimate mutual information with the pre-defined estimators included in the library, or to develop, analyze and evaluate new estimation strategies. Here we show different usage patterns, ranging from basic (simple to use, limited settings) to advanced (maximum flexibility, more complex to use).
To showcase the package, we first generate samples \(x,y\sim p(x,y)\) from a bivariate normal distribution with known true mutual information.
[1]:
import torch
from torch.distributions import MultivariateNormal, Normal
# Definition of the distribution
p_XY = MultivariateNormal(
loc=torch.tensor([0., 0.]),
covariance_matrix=torch.tensor([
[1.0, 0.9],
[0.9, 1.0]
])
)
p_X = p_Y = Normal(0, 1)
# I(x;y) = H(x)+H(y)-H(x,y)
true_mi = (
p_X.entropy() + p_Y.entropy() - p_XY.entropy()
).sum().item()
print(f"True Mutual Information: {true_mi} nats")
# Generate 100000 samples
samples = p_XY.sample([100000])
all_x = samples[:,0]
all_y = samples[:,1]
True Mutual Information: 0.8303654193878174 nats
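As a sanity check, for a bivariate normal with unit variances and correlation coefficient \(\rho\), the entropy difference \(H(x)+H(y)-H(x,y)\) simplifies to the closed form \(I(x;y) = -\frac{1}{2}\log(1-\rho^2)\). With \(\rho = 0.9\) (the off-diagonal entry of the covariance matrix above), this reproduces the printed value up to float32 rounding:

```python
import math

# Correlation coefficient of the bivariate normal defined above
# (unit variances and covariance 0.9, so rho = 0.9)
rho = 0.9

# Closed-form mutual information of a bivariate Gaussian, in nats
closed_form_mi = -0.5 * math.log(1.0 - rho ** 2)
print(f"Closed-form Mutual Information: {closed_form_mi:.6f} nats")
```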
[2]:
import matplotlib.pyplot as plt
import seaborn as sns
sns.set_style('whitegrid')
# Plot them
plt.scatter(
all_x[::10], all_y[::10],
marker='.',
alpha=0.05
)
plt.title("Samples")
plt.xlabel('$x$')
plt.ylabel('$y$');
Basic
We first showcase the simplest use case. The estimate_mi function estimates mutual information given paired samples x and y from some joint distribution \(p(x,y)\). It takes care of instantiating the estimator, training it and evaluating the mutual information.
[3]:
from torch_mist import estimate_mi
mi_estimate, log = estimate_mi(
estimator='js', # The mutual information estimator to use (here we use the Deep InfoMax JS estimator)
hidden_dims=[32, 32], # The hidden layers for the neural architectures (Estimator specific)
neg_samples=16, # The number of negative samples used to estimate the log partition function (Estimator specific, Default: 1)
data=(all_x, all_y), # The values of x and y
batch_size=128, # The batch size used for training (Default: 64)
max_epochs=5, # Number of maximum training epochs (Default: 10)
valid_percentage=0.1, # The percentage of data to use for validation (Default: 0.1)
evaluation_batch_size=256, # The batch size used for evaluation (Default: batch_size)
device='cpu', # The training device (Default: 'cpu')
)
print(f"Estimated Mutual Information: {mi_estimate} nats")
print(f"True Mutual Information: {true_mi} nats")
Instantiating the js estimator
Instantiating the estimator with {'hidden_dims': [32, 32], 'neg_samples': 16, 'x_dim': 1, 'y_dim': 1}
JS(
(ratio_estimator): JointCritic(
(joint_net): DenseNN(
(layers): ModuleList(
(0): Linear(in_features=2, out_features=32, bias=True)
(1): Linear(in_features=32, out_features=32, bias=True)
(2): Linear(in_features=32, out_features=1, bias=True)
)
(f): ReLU(inplace=True)
)
)
(baseline): ConstantBaseline()
(neg_samples): 16
)
Training the estimator
[Info]: patience is not specified, using patience=1 (~2% of training epochs) by default.
Best value: 0.8085264571105377
Best value: 0.8167389223847208
Best value: 0.825612061385867
Best value: 0.8273420024521744
Best value: 0.8285563580597504
[Warning]: The train procedure ended since max_epoch or max_iteration has been reached.Consider increasing the training time by specifying larger values of max_epochs or max_iterations.
Loading the weights saved at iteration 3520
Evaluating the value of Mutual Information
[Warning]: using the train_data to estimate the value of mutual information. Please specify test_data.
Estimated Mutual Information: 0.8261835863218283 nats
True Mutual Information: 0.8303654193878174 nats
The training curves can be read directly from the returned training log, which is a pandas.DataFrame object:
[4]:
# We can plot the estimated values by epoch
grid = sns.FacetGrid(log, col='name', hue='split', sharey=False, col_order=['loss', 'mutual_information'])
grid.axes[0,1].axhline(y=true_mi, ls='--', color='k', label='True $I(x;y)$')
grid.axes[0,1].set_ylim(0,)
grid.map(sns.lineplot, 'epoch', 'value')
plt.legend()
Command Line Interface
The torch_mist package provides basic functionality to estimate mutual information directly from the command line.
Given a file myfile.csv containing the columns (f1_1, f1_2, ..., f2_1, f2_2, ..., f3_1, ...), one can estimate the mutual information between the features f1 and f2 with:
mist data=csv data.filepath=myfile.csv mi_estimator=js x_key=f1 y_key=f2
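The command above expects each feature to be stored in columns named after its key followed by an index. As an illustration, the following sketch writes correlated pairs, similar to the samples used in this page, into myfile.csv with one column per feature (f1_1 for x and f2_1 for y). It uses only the Python standard library; the file name, the number of rows and the random seed are arbitrary choices made for this example.

```python
import csv
import random

# Draw correlated pairs (x, y) with correlation rho = 0.9 and unit variances,
# mirroring the bivariate normal used earlier (stdlib only, for illustration)
random.seed(0)
rho = 0.9
rows = []
for _ in range(1000):
    x = random.gauss(0.0, 1.0)
    y = rho * x + (1 - rho ** 2) ** 0.5 * random.gauss(0.0, 1.0)
    rows.append((x, y))

# Column names follow the <key>_<index> convention described above:
# f1_1 holds the 1-dimensional x and f2_1 holds the 1-dimensional y
with open("myfile.csv", "w", newline="") as f:
    writer = csv.writer(f)
    writer.writerow(["f1_1", "f2_1"])
    writer.writerows(rows)
```

The resulting file can then be passed to mist with x_key=f1 and y_key=f2, as in the command above.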
The same flags and options accepted by the estimate_mi function are also available from the command line.
Internal properties of the estimator and of the training procedure can also be specified. In the following example, we train on GPU (device=cuda), use AdamW for the optimization, use ELU as the nonlinearity, change the batch size to 256 and log on Weights & Biases (logger=wandb):
mist data=csv data.filepath=myfile.csv mi_estimator=js x_key=f1 y_key=f2 \
device=cuda \
estimation.optimizer_class._target_=torch.optim.AdamW \
+mi_estimator.nonlinearity=torch.nn.ELU \
params.batch_size=256 \
logger=wandb
To see the full list of available options, use:
mist data=csv --help
The mist CLI is implemented using Hydra, and the full configuration can be accessed here.
Advanced
Instead of using the estimate_mi function directly, it is possible to manually instantiate the estimators and specify additional details for the training and evaluation procedures, as described in the following sections.
Estimators
We start by defining a simple mutual information estimator based on Deep InfoMax (JS) with a joint critic architecture.
[5]:
from torch_mist.estimators import JS
from torch_mist.critic import JointCritic
from torch import nn
x_dim = y_dim = 1
# First we define a critic network that maps pairs of samples to a scalar.
# The JointCritic module takes care of concatenating the pairs of x and y, adapting the shapes when necessary
critic = JointCritic(
joint_net=nn.Sequential(
nn.Linear(x_dim+y_dim, 32),
nn.ReLU(True),
nn.Linear(32, 32),
nn.ReLU(True),
nn.Linear(32, 1)
)
)
# Then we pass it to the Jensen-Shannon estimator
estimator = JS(
critic=critic,
neg_samples=16
)
print(estimator)
JS(
(ratio_estimator): JointCritic(
(joint_net): Sequential(
(0): Linear(in_features=2, out_features=32, bias=True)
(1): ReLU(inplace=True)
(2): Linear(in_features=32, out_features=32, bias=True)
(3): ReLU(inplace=True)
(4): Linear(in_features=32, out_features=1, bias=True)
)
)
(baseline): ConstantBaseline()
(neg_samples): 16
)
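To make the shape handling concrete, here is a NumPy sketch of what a joint critic computes: it scores the concatenation \([x, y]\) both for the paired samples and for negative pairs in which each x is matched with a y taken from a different row of the batch. This is an illustration under our own assumptions (random weights, negatives built by rolling the batch), not the torch_mist implementation of JointCritic or of its negative sampling.

```python
import numpy as np

rng = np.random.default_rng(0)

# A toy joint critic f(x, y): a one-hidden-layer ReLU network applied to the
# concatenation [x, y] along the last axis (random weights, illustration only)
W1 = rng.normal(size=(2, 32))
b1 = np.zeros(32)
W2 = rng.normal(size=(32, 1))

def joint_critic(x, y):
    # Broadcast x and y to a common leading shape; the last axis holds the features
    x, y = np.broadcast_arrays(x, y)
    h = np.maximum(np.concatenate([x, y], axis=-1) @ W1 + b1, 0.0)
    return (h @ W2)[..., 0]

batch_size, neg_samples = 8, 4
x = rng.normal(size=(batch_size, 1))
y = rng.normal(size=(batch_size, 1))

# Scores for the positive (paired) samples: shape [batch_size]
positive_scores = joint_critic(x, y)

# Negative pairs: match each x_i with y rolled by 1..neg_samples rows.
# Requires neg_samples < batch_size so that no roll recovers the original pairing.
y_neg = np.stack([np.roll(y, shift=k + 1, axis=0) for k in range(neg_samples)])
negative_scores = joint_critic(x[None], y_neg)  # shape [neg_samples, batch_size]
```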
Each estimator class EstimatorName(...) comes with a factory function estimator_name(...) that instantiates it from a limited set of options, building the neural architectures internally.
[6]:
from torch_mist.estimators import js # The factory function for the JS estimator
estimator_2 = js(
x_dim=x_dim,
y_dim=y_dim,
hidden_dims=[32, 32],
neg_samples=16,
critic_type='joint',
)
print(estimator_2)
JS(
(ratio_estimator): JointCritic(
(joint_net): DenseNN(
(layers): ModuleList(
(0): Linear(in_features=2, out_features=32, bias=True)
(1): Linear(in_features=32, out_features=32, bias=True)
(2): Linear(in_features=32, out_features=1, bias=True)
)
(f): ReLU(inplace=True)
)
)
(baseline): ConstantBaseline()
(neg_samples): 16
)
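Note that the two estimators defined above are equivalent: both use a joint critic with two hidden layers of 32 units and ReLU nonlinearities, and 16 negative samples.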
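One quick way to check that the two critics have the same size is to count their trainable parameters: a dense layer mapping n_in to n_out units has n_in * n_out weights and n_out biases. The helper below (a hypothetical name, not part of torch_mist) applies this to the layer sizes [2, 32, 32, 1] shared by both critics:

```python
def count_dense_parameters(layer_sizes):
    # Sum weights (n_in * n_out) and biases (n_out) over consecutive layer pairs
    return sum(
        n_in * n_out + n_out
        for n_in, n_out in zip(layer_sizes[:-1], layer_sizes[1:])
    )

# Input x_dim + y_dim = 2, two hidden layers of 32 units, scalar output
n_params = count_dense_parameters([2, 32, 32, 1])
print(n_params)  # 1185
```

For the manually defined critic, sum(p.numel() for p in critic.joint_net.parameters()) should return the same count, since the ReLU modules carry no parameters.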
Training
Instead of passing the paired x and y directly, we can define the DataLoader and the optimizer ourselves and train the estimator with a simple training loop.
[7]:
from tqdm.autonotebook import tqdm
from torch.optim import AdamW
from torch_mist.utils.data import SampleDataset
from torch.utils.data import DataLoader
import pandas as pd
# We use a simple wrapper to make a torch.utils.data.Dataset object using the pairs of x and y
dataset = SampleDataset({'x': all_x, 'y': all_y})
# Then we make a DataLoader
dataloader = DataLoader(
dataset,
batch_size=128,
shuffle=True
)
# Use the AdamW optimizer with a learning rate of 5e-4
opt = AdamW(estimator.parameters(), lr=5e-4)
# Train for 5 epochs
n_epochs = 5
# Log to visualize the training progress
log = []
iteration = 0
# For each epoch
for epoch in range(n_epochs):
    # Sample a batch of pairs
    for data in tqdm(dataloader):
        x, y = data['x'], data['y']
        # Compute the loss
        loss = estimator(x, y)
        # And the corresponding estimation of Mutual Information (optional, for logging purposes)
        mi = estimator.mutual_information(x, y)
        # Update the parameters of the estimator
        opt.zero_grad()
        loss.backward()
        opt.step()
        # Log the loss
        log.append({
            'iteration': iteration,
            'name': 'loss',
            'value': loss.item(),
        })
        # Log the mutual information
        log.append({
            'iteration': iteration,
            'name': 'mutual_information',
            'value': mi.item(),
        })
        iteration += 1
pd_log = pd.DataFrame(log)
We can plot the loss and the estimated mutual information as a function of the training iteration:
[8]:
grid = sns.FacetGrid(pd_log, col='name', sharey=False, col_order=['loss', 'mutual_information'])
grid.map(sns.scatterplot, 'iteration', 'value', marker='.', alpha=0.1)
grid.axes[0,1].axhline(y=true_mi, ls='--', color='k', label='True $I(x;y)$')
grid.axes[0,1].set_ylim(0,)
grid.axes[0,1].legend();
Alternatively, the train_mi_estimator utility trains a mutual information estimator given either a DataLoader or paired x and y tensors together with a batch_size. In the latter case, by default, the function holds out a fraction valid_percentage=0.1 of the samples as a validation set, which is used for early stopping. The two procedures reported below are equivalent (neither uses a validation set).
[10]:
from torch_mist.utils import train_mi_estimator
use_train_dataloader = True
if use_train_dataloader:
    log = train_mi_estimator(
        estimator=estimator_2,
        data=dataloader,
        max_epochs=5,
        valid_percentage=0,
        optimizer_class=AdamW,
    )
else:
    log = train_mi_estimator(
        estimator=estimator_2,
        max_epochs=5,
        data=(all_x, all_y),
        batch_size=128,
        valid_percentage=0.0,  # Do not use a validation set
        optimizer_class=AdamW,
    )
[11]:
grid = sns.FacetGrid(log, col='name', sharey=False, col_order=['loss', 'mutual_information'])
grid.map(sns.scatterplot, 'iteration', 'value', marker='.', alpha=0.1)
grid.axes[0,1].axhline(y=true_mi, ls='--', color='k', label='True $I(x;y)$')
grid.axes[0,1].legend();
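The sketch below illustrates the general idea behind valid_percentage and early stopping: hold out a fraction of the samples, then stop once the validation estimate has not improved for a given number of epochs (the patience mentioned in the log printed by estimate_mi). The helpers split_indices and early_stopping_epoch are hypothetical, and the exact criterion used by train_mi_estimator (for instance, which quantity it monitors) is an assumption here, not a description of the library.

```python
import random

def split_indices(n_samples, valid_percentage, seed=0):
    # Shuffle the sample indices and hold out a fraction valid_percentage for validation
    indices = list(range(n_samples))
    random.Random(seed).shuffle(indices)
    n_valid = int(n_samples * valid_percentage)
    return indices[n_valid:], indices[:n_valid]

def early_stopping_epoch(valid_scores, patience):
    # valid_scores: validation estimate after each epoch (higher is better).
    # Returns (best_epoch, stop_epoch): training stops once the best value
    # has not improved for `patience` consecutive epochs.
    best_value, best_epoch = float("-inf"), -1
    for epoch, value in enumerate(valid_scores):
        if value > best_value:
            best_value, best_epoch = value, epoch
        elif epoch - best_epoch >= patience:
            return best_epoch, epoch
    return best_epoch, len(valid_scores) - 1

train_idx, valid_idx = split_indices(100000, valid_percentage=0.1)
# Hypothetical validation values, used only to exercise the stopping rule
best, stop = early_stopping_epoch([0.80, 0.81, 0.805, 0.79, 0.82], patience=2)
```

In this toy run training stops at epoch 3 and the weights from epoch 1 would be restored, which mirrors the "Loading the weights saved at iteration ..." message shown earlier.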
Evaluation
We can now evaluate the estimator on the entire dataset by averaging the batch-wise estimates:
[12]:
import numpy as np
mi_estimates = []
# Gradients are not needed for evaluation
with torch.no_grad():
    for data in tqdm(dataloader):
        mi_estimates.append(estimator.mutual_information(data['x'], data['y']).item())
mi_estimate = np.mean(mi_estimates)
print(f"Estimated Mutual Information: {mi_estimate} nats")
print(f"True Mutual Information: {true_mi} nats")
Estimated Mutual Information: 0.8230861622049376 nats
True Mutual Information: 0.8303654193878174 nats
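The plain mean of batch-wise estimates gives the last batch the same weight as a full one, even though with 100000 samples and batch_size=128 the last batch only contains 32 samples. A minimal sketch of a size-weighted average (the helper name is hypothetical, not a torch_mist function):

```python
import numpy as np

def weighted_dataset_average(batch_estimates, batch_sizes):
    # Weight each batch-wise estimate by the number of samples in its batch,
    # so that a smaller final batch counts proportionally less than a full one
    return float(np.average(batch_estimates, weights=batch_sizes))

# With 100000 samples and batch_size=128 there are 781 full batches
# and a last batch of 100000 - 781 * 128 = 32 samples
n_samples, batch_size = 100000, 128
batch_sizes = [batch_size] * (n_samples // batch_size)
if n_samples % batch_size:
    batch_sizes.append(n_samples % batch_size)
```

In the evaluation loop above, the batch sizes could be collected with len(data['x']) next to each estimate and passed to this helper together with mi_estimates.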
Alternatively, we can use the provided evaluate_mi utility function:
[13]:
from torch_mist.utils import evaluate_mi
mi_estimate = evaluate_mi(
estimator_2,
data=dataloader
)
print(f"Estimated Mutual Information: {mi_estimate} nats")
print(f"True Mutual Information: {true_mi} nats")
Estimated Mutual Information: 0.8208410611847783 nats
True Mutual Information: 0.8303654193878174 nats
Like train_mi_estimator, the evaluate_mi function accepts either a DataLoader or paired x and y tensors together with a batch_size:
[14]:
mi_estimate = evaluate_mi(
estimator_2,
data=(all_x, all_y),
batch_size=128
)
print(f"Estimated Mutual Information: {mi_estimate} nats")
print(f"True Mutual Information: {true_mi} nats")
Estimated Mutual Information: 0.822260448999722 nats
True Mutual Information: 0.8303654193878174 nats
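The dual input format shared by train_mi_estimator and evaluate_mi can be pictured with the following plain-Python sketch. The helper iterate_batches is hypothetical: it only illustrates the idea of accepting either an iterable of {'x': ..., 'y': ...} batches (like the DataLoader above) or a tuple (x, y) that is split into chunks of batch_size, and it does not reproduce the torch_mist implementation, which may for instance shuffle the samples or build a DataLoader internally.

```python
from typing import Dict, Iterable, Iterator, Optional, Sequence, Tuple, Union

Batch = Dict[str, Sequence[float]]

def iterate_batches(
    data: Union[Iterable[Batch], Tuple[Sequence[float], Sequence[float]]],
    batch_size: Optional[int] = None,
) -> Iterator[Batch]:
    """Yield {'x': ..., 'y': ...} batches from either a loader or paired x and y."""
    if isinstance(data, tuple):
        x, y = data
        if len(x) != len(y):
            raise ValueError("x and y must contain the same number of samples")
        if batch_size is None:
            raise ValueError("batch_size is required when passing paired x and y")
        # Split the paired samples into consecutive chunks of batch_size
        for start in range(0, len(x), batch_size):
            yield {"x": x[start:start + batch_size], "y": y[start:start + batch_size]}
    else:
        # Anything else is assumed to already be an iterable of batches
        yield from data
```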