Building Blocks for RBIG

# @title Install Packages
# %%capture
try:
    import sys, os
    from pyprojroot import here

    # spyder up to find the root
    root = here(project_files=[".here"])

    # append to path
    sys.path.append(str(root))
except ModuleNotFoundError:
    import os

    os.system("pip install chex")
    os.system("pip install git+https://github.com/IPL-UV/rbig_jax.git#egg=rbig_jax")
# jax packages
import jax
import jax.numpy as jnp
from jax.config import config

# import chex
config.update("jax_enable_x64", True)

import chex
import numpy as np
from functools import partial

KEY = jax.random.PRNGKey(123)

# logging
import tqdm
import wandb

# plot methods
import matplotlib.pyplot as plt
import seaborn as sns
import corner
from IPython.display import HTML

sns.reset_defaults()
sns.set_context(context="poster", font_scale=0.7)

%load_ext lab_black
%matplotlib inline
%load_ext autoreload
%autoreload 2

Demo Data

from sklearn import datasets
from sklearn.preprocessing import StandardScaler

# %%wandb
# get data
seed = 123
n_samples = 5_000
n_features = 2
noise = 0.05

X, y = datasets.make_moons(n_samples=n_samples, noise=noise, random_state=seed)

data = X[:]
# plot data
fig = corner.corner(data, color="blue", hist_bin_factor=2)
X = jnp.array(data, dtype=np.float64)

Model

Layer I - Univariate Histogram
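Before calling the library transform, it helps to see what a marginal (per-feature) uniformization does. The snippet below is a minimal, hypothetical sketch based on the empirical CDF; the library's histogram transform additionally handles support extension, the regularization parameter alpha, and a fixed-precision grid.

import jax.numpy as jnp

def empirical_cdf_sketch(x_feature, n_quantiles=1_000):
    """Hypothetical sketch: map one feature to (0, 1) via its empirical CDF."""
    # the data quantiles act as the inverse CDF on a uniform grid
    quantiles = jnp.quantile(x_feature, jnp.linspace(0.0, 1.0, n_quantiles))
    u_grid = jnp.linspace(0.0, 1.0, n_quantiles)
    # interpolating that grid at the data points gives the CDF values
    return jnp.interp(x_feature, quantiles, u_grid)

# apply per feature (column) of the demo data
X_uniform_sketch = jnp.stack(
    [empirical_cdf_sketch(X[:, d]) for d in range(X.shape[1])], axis=-1
)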

from rbig_jax.transforms.histogram import InitUniHistTransform

# histogram params
support_extension = 20
alpha = 1e-5
precision = 1_000
nbins = None  # e.g. init_bin_estimator("sqrt") or int(np.sqrt(X.shape[0]))
jitted = True

# initialize
shape = X.shape
n_samples = shape[0]
init_hist_f = InitUniHistTransform(
    n_samples=n_samples,
    nbins=nbins,
    support_extension=support_extension,
    precision=precision,
    alpha=alpha,
    jitted=jitted,
)

Init Function

# initialize bijector
X_u, hist_bijector = init_hist_f.transform_and_bijector(X)

# forward transformation
X_l1 = hist_bijector.forward(X)

# inverse transformation
X_approx = hist_bijector.inverse(X_l1)

# log-det-Jacobian of the forward transformation (evaluated at the input X)
X_l1_ldj = hist_bijector.forward_log_det_jacobian(X)

# plot Transformations
fig = corner.corner(X_l1, color="red", hist_bin_factor=2)
fig.suptitle("Forward Transformation")
fig = corner.corner(X_approx, color="red", hist_bin_factor=2)
fig.suptitle("Inverse Transformation")
fig = corner.corner(X_l1_ldj, color="red", hist_bin_factor=2)
fig.suptitle("Gradient Transformation")
from rbig_jax.transforms.histogram import InitUniHistTransform, init_bin_estimator
from rbig_jax.transforms.kde import InitUniKDETransform, estimate_bw

# histogram params
support_extension = 20
alpha = 1e-5
precision = 1_000
nbins = None  # e.g. init_bin_estimator("sqrt") or int(np.sqrt(X.shape[0]))
jitted = True

# KDE specific Transform
bw = "scott"  # estimate_bw(X.shape[0], 1, "scott")

method = "kde"
# initialize histogram transformation
if method == "histogram":

    init_hist_f = InitUniHistTransform(
        n_samples=X.shape[0],
        nbins=nbins,
        support_extension=support_extension,
        precision=precision,
        alpha=alpha,
        jitted=jitted,
    )
elif method == "kde":

    init_hist_f = InitUniKDETransform(
        shape=X.shape, support_extension=support_extension, precision=precision, bw=bw
    )
else:
    raise ValueError(f"Unrecognized transform: {method}")

Transformations

# initialize bijector
X_u, hist_bijector = init_hist_f.transform_and_bijector(X)

# forward transformation
X_l1 = hist_bijector.forward(X)

# inverse transformation
X_approx = hist_bijector.inverse(X_l1)

# log-det-Jacobian of the forward transformation (evaluated at the input X)
X_l1_ldj = hist_bijector.forward_log_det_jacobian(X)

# plot Transformations
fig = corner.corner(X_l1, color="red", hist_bin_factor=2)
fig.suptitle("Forward Transformation")
fig = corner.corner(X_approx, color="red", hist_bin_factor=2)
fig.suptitle("Inverse Transformation")
fig = corner.corner(X_l1_ldj, color="red", hist_bin_factor=2)
fig.suptitle("Gradient Transformation")

Layer II - Inverse Gaussian CDF
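Conceptually, this layer pushes the uniform marginals from Layer I through the inverse CDF of a standard Gaussian. The snippet below is a minimal sketch using jax.scipy.special.ndtri (the inverse standard-normal CDF); the eps clipping mirrors the eps parameter used by the library transform, and everything else is an illustrative assumption.

import jax.numpy as jnp
from jax.scipy.special import ndtri
from jax.scipy.stats import norm

def inverse_gauss_cdf_sketch(u, eps=1e-5):
    """Hypothetical sketch: uniform marginals -> Gaussian marginals."""
    u_clipped = jnp.clip(u, eps, 1.0 - eps)  # avoid +/- infinity at exactly 0 or 1
    z = ndtri(u_clipped)                     # inverse standard-normal CDF
    # element-wise log|det J| of u -> z is -log phi(z)
    ldj = -norm.logpdf(z)
    return z, ldj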

from rbig_jax.transforms.inversecdf import InitInverseGaussCDF

# univariate normalization Gaussianization parameters
eps = 1e-5
jitted = True

# initialize histogram transformation
init_icdf_f = InitInverseGaussCDF(eps=eps, jitted=jitted)

Transformations

# forward with bijector
X_l2, icdf_bijector = init_icdf_f.transform_and_bijector(X_l1)

# alternatively - forward with no bijector
X_l2_ = icdf_bijector.forward(X_l1)

chex.assert_tree_all_close(X_l2_, X_l2)

# inverse transformation
X_l1_approx = icdf_bijector.inverse(X_l2)
chex.assert_tree_all_close(X_l1_approx, X_l1, rtol=1e-5)

# gradient transformation
X_l2_ldj = icdf_bijector.forward_log_det_jacobian(X_l1)
# plot Transformations
fig = corner.corner(X_l2, color="red", hist_bin_factor=2)
fig.suptitle("Forward Transformation")
fig = corner.corner(X_l1_approx, color="red", hist_bin_factor=2)
fig.suptitle("Inverse Transformation")
fig = corner.corner(X_l2_ldj, color="red", hist_bin_factor=2)
fig.suptitle("Gradient Transformation")

PCA Transformation
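The rotation is a linear, orthogonal map, so its log-det-Jacobian is zero. The snippet below is a minimal sketch of a PCA rotation via an eigendecomposition of the covariance matrix (centering the data is an assumption here); the library's InitPCARotation is what is actually used in this notebook.

import jax.numpy as jnp

def pca_rotation_sketch(x):
    """Hypothetical sketch: rotate the data onto its principal axes."""
    x_centered = x - x.mean(axis=0)
    cov = jnp.cov(x_centered, rowvar=False)  # (D, D) covariance matrix
    _, eigvecs = jnp.linalg.eigh(cov)        # orthonormal eigenvectors
    z = x_centered @ eigvecs                 # forward rotation
    ldj = jnp.zeros(x.shape[0])              # log|det R| = 0 for an orthogonal R
    return z, ldj, eigvecs

# the inverse is just the transpose: x_centered = z @ eigvecs.T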

from rbig_jax.transforms.rotation import InitPCARotation

# initialize histogram transformation
init_pca_f = InitPCARotation(jitted=True)
# forward with bijector
X_l3, pca_bijector = init_pca_f.transform_and_bijector(X_l2)

# alternatively - forward with no bijector
X_l3_ = pca_bijector.forward(X_l2)

chex.assert_tree_all_close(X_l3_, X_l3)

# inverse transformation
X_l2_approx = pca_bijector.inverse(X_l3)

chex.assert_tree_all_close(X_l2_approx, X_l2, rtol=1e-3)

# gradient transformation
X_l3_ldj = pca_bijector.forward_log_det_jacobian(X_l2)

chex.assert_tree_all_close(X_l3_ldj, jnp.zeros_like(X_l3_ldj))
# plot Transformations
fig = corner.corner(X_l3, color="red", hist_bin_factor=2)
fig.suptitle("Forward Transformation")
fig = corner.corner(X_l2_approx, color="red", hist_bin_factor=2)
fig.suptitle("Inverse Transformation")

RBIG Blocks

An RBIG block chains the layers above into a single step (written out as a formula below):

  1. Marginal Gaussianization (histogram/KDE uniformization followed by the inverse Gaussian CDF)

  2. Rotation (random or, as here, PCA)
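In symbols, with \(F\) the dimension-wise marginal CDF, \(\Phi^{-1}\) the inverse Gaussian CDF, and \(R\) the rotation matrix, a single block computes

\[
\Psi(\mathbf{x}) = R \, \Phi^{-1}\!\left(F(\mathbf{x})\right).
\]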

from rbig_jax.transforms.block import RBIGBlockInit

# create a list of transformations
init_functions = [init_hist_f, init_icdf_f, init_pca_f]

# create an RBIG "block" init
rbig_block_init = RBIGBlockInit(init_functions=init_functions)
# forward and params
X_g, bijectors = rbig_block_init.forward_and_bijector(X)

# alternatively just the forward
X_g = rbig_block_init.forward(X)
fig = corner.corner(X_g, color="red", hist_bin_factor=2)
fig.suptitle("Forward Transformation")

Forward and Inverse Transformations

Here we want to chain the transformations together. We have initialized our bijectors, but it would be nice to have a convenient way to loop through them and compute all of the relevant quantities, e.g. forward, inverse, log_det_jacobian, and combinations of them.

In this package, the BijectorChain class gives us that flexibility.
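As a rough mental model (a sketch, not the library's implementation), the chain composes the forward passes in order, accumulates the log-det-Jacobians along the way, and inverts by walking through the bijectors in reverse:

def chain_forward_and_log_det(bijectors, x):
    """Sketch: compose forward passes and accumulate log|det J|."""
    total_ldj = 0.0
    for bijector in bijectors:
        # log|det J| is evaluated at the current layer's input ...
        total_ldj = total_ldj + bijector.forward_log_det_jacobian(x)
        # ... and then we move on to the next layer's input
        x = bijector.forward(x)
    return x, total_ldj

def chain_inverse(bijectors, y):
    """Sketch: invert the chain by applying the inverses in reverse order."""
    for bijector in reversed(bijectors):
        y = bijector.inverse(y)
    return y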

from rbig_jax.transforms.base import BijectorChain

# create a list of BIJECTORS (not init functions)
bijectors = [hist_bijector, icdf_bijector, pca_bijector]

# create rbig_block
rbig_block = BijectorChain(bijectors=bijectors)
# forward with bijector
X_l3 = rbig_block.forward(X)

# inverse transformation
X_approx = rbig_block.inverse(X_l3)

chex.assert_tree_all_close(X_approx, X, rtol=1e-4)

# gradient transformation
X_l3_ldj = rbig_block.forward_log_det_jacobian(X)

# forward and gradient transformation
X_l3_, X_l3_ldj_ = rbig_block.forward_and_log_det(X)

chex.assert_tree_all_close(X_l3_, X_l3)
chex.assert_tree_all_close(X_l3_ldj, X_l3_ldj_)
# plot Transformations
fig = corner.corner(X_l3, color="red", hist_bin_factor=2)
fig.suptitle("Forward Transformation")
fig = corner.corner(X_approx, color="red", hist_bin_factor=2)
fig.suptitle("Inverse Transformation")
fig = corner.corner(X_l3_ldj, color="red", hist_bin_factor=2)
fig.suptitle("Gradient Transformation")

Multiple Layers

It's very evident that a single RBIG block isn't enough; we need multiple layers. All we need to do is loop through the init_ methods until we are satisfied. Once we're done, we can create another chain and check how good our transformation is.

%%time

import itertools

itercount = itertools.count()  # counts 0, 1, 2, ... so the loop below runs n_blocks times

n_blocks = 20

# initialize rbig block
init_functions = [
    init_hist_f,
    init_icdf_f,
    init_pca_f
]

# initialize RBIG Init Block
rbig_block_init = RBIGBlockInit(init_functions=init_functions)

# initialize list of bijectors
bijectors = list()

# initialize transform
X_g = X.copy()

plot_steps = False

while next(itercount) < n_blocks:
    
    # fit RBIG block
    X_g, ibijector = rbig_block_init.forward_and_bijector(X_g)
    
    if plot_steps:
        fig = corner.corner(X_g, color="blue", hist_bin_factor=2)
    
    # append bijectors 
    bijectors += ibijector
CPU times: user 1min 4s, sys: 1min 4s, total: 2min 9s
Wall time: 11.7 s

Check Transformation

Now let’s check the Gaussianized data to see how well we did.

fig = corner.corner(X_g, color="red", hist_bin_factor=2)

This looks pretty good. Now let's see how good the inverse transformation is. Again, we create a BijectorChain which will loop through all of the transformations.

# create rbig_model
rbig_model = BijectorChain(bijectors=bijectors)
%%time
# forward with bijector
X_g_ = rbig_model.forward(X)

chex.assert_tree_all_close(X_g_, X_g, rtol=1e-4)

# inverse transformation
X_approx = rbig_model.inverse(X_g)

chex.assert_tree_all_close(X_approx, X, rtol=1e-2)

# gradient transformation
X_g_ldj = rbig_model.forward_log_det_jacobian(X)

# forward and gradient transformation
X_g__, X_g_ldj_ = rbig_model.forward_and_log_det(X)

chex.assert_tree_all_close(X_g__, X_g_)
chex.assert_tree_all_close(X_g_ldj, X_g_ldj_)
CPU times: user 12 s, sys: 864 ms, total: 12.9 s
Wall time: 12.8 s
# plot Transformations
fig = corner.corner(X_g, color="red", hist_bin_factor=2)
fig.suptitle("Forward Transformation")
fig = corner.corner(X_approx, color="red", hist_bin_factor=2)
fig.suptitle("Inverse Transformation")
fig = corner.corner(X_g_ldj, color="red", hist_bin_factor=2)
fig.suptitle("Gradient Transformation")

Gaussianization Flow

The bijector chains allow us to do some extra things like density estimation and sampling. We can also use the GaussianizationFlow class, which is exactly like the BijectorChain class but with some additional benefits, like calculating log probabilities. This may seem redundant for the iterative method, but it is very helpful for fully parameterized Gaussianization; i.e. the end result is the same, but the way the parameters are found is different.
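Under the hood, the extra benefit is just the change-of-variables formula: with the chained transform \(f\) mapping data to the Gaussian latent space and \(p_Z\) the standard normal base distribution,

\[
\log p_X(\mathbf{x}) = \log p_Z\big(f(\mathbf{x})\big) + \log \left| \det \frac{\partial f(\mathbf{x})}{\partial \mathbf{x}} \right|.
\]

The code in the density estimation section below computes exactly these two terms by hand.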

from rbig_jax.models import GaussianizationFlow
from distrax._src.distributions.normal import Normal

# initialize base distribution
base_dist = Normal(jnp.zeros((2,)), jnp.ones((2,)))

# initialize flow model
rbig_model = GaussianizationFlow(base_dist=base_dist, bijectors=bijectors)

Density Estimation

Here we will do an example of density estimation. The Python code below is equivalent to what the model computes internally (using the evaluation grid xyinput defined later): propagate the points through the chain, evaluate the log probability under the Gaussian base distribution, and add the log-determinant of the Jacobian.

# propagate through the chain
X_g_grid, X_ldj_grid = rbig_model.forward_and_log_det(xyinput)

# calculate log prob
base_dist = Normal(jnp.zeros((2,)), jnp.ones((2,)))

latent_prob = base_dist.log_prob(X_g_grid)

# calculate log probability
X_log_prob = latent_prob.sum(axis=1) + X_ldj_grid.sum(axis=1)

However, using the score_samples method is a lot more convenient.

# Original Density
n_samples = 1_000_000
noise = 0.05
seed = 42
X_plot, _ = datasets.make_moons(n_samples=n_samples, noise=noise, random_state=seed)
%%time

n_grid = 200
buffer = 0.01
xline = jnp.linspace(X[:, 0].min() - buffer, X[:, 0].max() + buffer, n_grid)
yline = jnp.linspace(X[:, 1].min() - buffer, X[:, 1].max() + buffer, n_grid)
xgrid, ygrid = jnp.meshgrid(xline, yline)
xyinput = jnp.concatenate([xgrid.reshape(-1, 1), ygrid.reshape(-1, 1)], axis=1)


X_log_prob = rbig_model.score_samples(xyinput)
CPU times: user 6.91 s, sys: 384 ms, total: 7.29 s
Wall time: 5.96 s
# Original Density
from matplotlib import cm


# Estimated Density
cmap = cm.magma  # "Reds"
probs = jnp.exp(X_log_prob)


fig, ax = plt.subplots(ncols=2, figsize=(12, 5))
h = ax[0].hist2d(
    X_plot[:, 0], X_plot[:, 1], bins=512, cmap=cmap, density=True, vmin=0.0, vmax=1.0
)
ax[0].set_title("True Density")
ax[0].set(
    xlim=[X_plot[:, 0].min(), X_plot[:, 0].max()],
    ylim=[X_plot[:, 1].min(), X_plot[:, 1].max()],
)

h1 = ax[1].scatter(
    xyinput[:, 0], xyinput[:, 1], s=1, c=probs, cmap=cmap, vmin=0.0, vmax=1.0
)
ax[1].set(
    xlim=[xyinput[:, 0].min(), xyinput[:, 0].max()],
    ylim=[xyinput[:, 1].min(), xyinput[:, 1].max()],
)
# plt.colorbar(h1)
ax[1].set_title("Estimated Density")


plt.tight_layout()
plt.show()

Score (Negative Log-Likelihood)

nll = rbig_model.score(X)
print(f"NLL Score: {nll:.4f}")
NLL Score: 0.6803

Sampling

Sampling is another useful application: we draw samples from the Gaussian base distribution and push them backwards through the inverse of the bijector chain.
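As a rough sketch of what this amounts to (the convenient way is rbig_model.sample, used below; the chain-based version here is only illustrative):

# draw latent samples from the Gaussian base distribution (distrax API)
key = jax.random.PRNGKey(42)
z_samples = base_dist.sample(seed=key, sample_shape=(10_000,))

# map them back to data space through the inverse of the bijector chain
chain = BijectorChain(bijectors=bijectors)
x_samples_sketch = chain.inverse(z_samples)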

%%time

# number of samples
n_samples = 100_000
seed = 42

X_samples = rbig_model.sample(seed=seed, n_samples=n_samples)
CPU times: user 5.96 s, sys: 487 ms, total: 6.44 s
Wall time: 3.79 s
fig = corner.corner(X, color="blue", label="Original Data")
fig.suptitle("Original Data")
plt.show()

fig2 = corner.corner(X_samples, color="purple")
fig2.suptitle("Generated Samples")
plt.show()

Better Training

So far we have assumed that \(20\) blocks would be necessary to train the model. But how do we know that this gives the best model? We would rather have a stopping criterion than an ad-hoc choice.

In RBIG, we use the information reduction loss which essentially checks how much information content is being removed with each iteration. We are effectively creating a more and more independent distribution with every marginal Gaussianization + rotation. So naturally, we can simply check how much the information is being reduced between iterations. If there are no changes, we can stop.
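As a rough sketch of the stopping rule (the library's init_info_loss below is the real implementation; the helper and its arguments here are hypothetical), we keep the per-layer information reduction and stop once the most recent zero_tolerance layers show essentially no change:

import numpy as np

def should_stop(info_reductions, zero_tolerance=60, tol=0.0):
    """Hypothetical sketch: stop once the recent layers no longer reduce information."""
    if len(info_reductions) < zero_tolerance:
        return False
    recent = np.asarray(info_reductions[-zero_tolerance:])
    # no appreciable information reduction over the last `zero_tolerance` layers
    return bool(np.all(recent <= tol))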

Loss Function

We can initialize the info loss function here.

from rbig_jax.losses import init_info_loss

# define loss parameters
max_layers = 1_000
zero_tolerance = 60
p = 0.5
jitted = True

# initialize info loss function
loss = init_info_loss(
    n_samples=X.shape[0],
    max_layers=max_layers,
    zero_tolerance=zero_tolerance,
    p=p,
    jitted=jitted,
)

Training

from rbig_jax.training.iterative import train_info_loss_model

# define training params
verbose = True
n_layers_remove = 50
interval = 10

# run iterative training
X_g, rbig_model_info = train_info_loss_model(
    X=X,
    rbig_block_init=rbig_block_init,
    loss=loss,
    verbose=verbose,
    interval=interval,
    n_layers_remove=n_layers_remove,
)
Layer 10 - Cum. Info Reduction: 2.700 - Elapsed Time: 6.2950 secs
Layer 20 - Cum. Info Reduction: 2.850 - Elapsed Time: 11.0626 secs
Layer 30 - Cum. Info Reduction: 3.058 - Elapsed Time: 15.9575 secs
Layer 40 - Cum. Info Reduction: 3.216 - Elapsed Time: 20.8868 secs
Layer 50 - Cum. Info Reduction: 3.263 - Elapsed Time: 25.7276 secs
Layer 60 - Cum. Info Reduction: 3.263 - Elapsed Time: 30.5382 secs
Layer 70 - Cum. Info Reduction: 3.263 - Elapsed Time: 35.3184 secs
Layer 80 - Cum. Info Reduction: 3.277 - Elapsed Time: 40.2019 secs
Layer 90 - Cum. Info Reduction: 3.277 - Elapsed Time: 44.9938 secs
Layer 100 - Cum. Info Reduction: 3.277 - Elapsed Time: 49.7303 secs
Layer 110 - Cum. Info Reduction: 3.277 - Elapsed Time: 54.5568 secs
Layer 120 - Cum. Info Reduction: 3.320 - Elapsed Time: 59.2062 secs
Layer 130 - Cum. Info Reduction: 3.320 - Elapsed Time: 64.0455 secs
Layer 140 - Cum. Info Reduction: 3.367 - Elapsed Time: 68.8304 secs
Layer 150 - Cum. Info Reduction: 3.367 - Elapsed Time: 73.6238 secs
Layer 160 - Cum. Info Reduction: 3.367 - Elapsed Time: 78.3725 secs
Layer 170 - Cum. Info Reduction: 3.367 - Elapsed Time: 83.1724 secs
Layer 180 - Cum. Info Reduction: 3.367 - Elapsed Time: 87.9750 secs
Layer 190 - Cum. Info Reduction: 3.367 - Elapsed Time: 92.6474 secs
Converged at Layer: 192
Final Number of layers: 142 (Blocks: 47)
Total Time: 93.7292 secs

Information Reduction Evolution

fig, ax = plt.subplots()
ax.plot(rbig_model_info.info_loss, color="red")
ax.set(xlabel="Iterations", ylabel="$\Delta$ Info. Reduction")
plt.show()

Negative Log-Likelihood

nll = rbig_model_info.score(X)
print(f"NLL Score: {nll:.4f}")
NLL Score: 0.5901

Density Estimation (Revisited)

%%time

X_log_prob = rbig_model_info.score_samples(xyinput)
CPU times: user 32.5 s, sys: 2.57 s, total: 35.1 s
Wall time: 19.4 s
# Estimated Density
cmap = cm.magma  # "Reds"
probs = jnp.exp(X_log_prob)


fig, ax = plt.subplots(ncols=2, figsize=(12, 5))
h = ax[0].hist2d(
    X_plot[:, 0], X_plot[:, 1], bins=512, cmap=cmap, density=True, vmin=0.0, vmax=1.0
)
ax[0].set_title("True Density")
ax[0].set(
    xlim=[X_plot[:, 0].min(), X_plot[:, 0].max()],
    ylim=[X_plot[:, 1].min(), X_plot[:, 1].max()],
)

h1 = ax[1].scatter(
    xyinput[:, 0], xyinput[:, 1], s=1, c=probs, cmap=cmap, vmin=0.0, vmax=1.0
)
ax[1].set(
    xlim=[xyinput[:, 0].min(), xyinput[:, 0].max()],
    ylim=[xyinput[:, 1].min(), xyinput[:, 1].max()],
)
# plt.colorbar(h1)
ax[1].set_title("Estimated Density")


plt.tight_layout()
plt.show()

Sampling Revisited

%%time

# number of samples
n_samples = 100_000
seed = 42

X_samples = rbig_model_info.sample(seed=seed, n_samples=n_samples)
CPU times: user 30.6 s, sys: 2.11 s, total: 32.7 s
Wall time: 8.93 s
fig = corner.corner(X, color="blue", label="Original Data")
fig.suptitle("Original Data")
plt.show()

fig2 = corner.corner(X_samples, color="purple")
fig2.suptitle("Generated Samples")
plt.show()

Saving and Loading

Oftentimes it is nice to be able to save and load models. This is useful for checkpointing (during training) and also for convenience if you're doing research on Google Colab.

Fortunately, everything here is a Python object, so we can easily save and load our models via pickle.

Saving

Due to the internals of Python (and design choices within this library), one can only store objects. That includes the rbig_block, the bijectors, and the rbig_model. It does not include the rbig_block_init, for example, because that isn't a plain object; it's a function with some local params.

import joblib

joblib.dump(rbig_model_info, "rbig_model_test.pickle")
['rbig_model_test.pickle']

Loading

Loading is straightforward!

rbig_model_loaded = joblib.load("rbig_model_test.pickle")

Simple Test

The saved and loaded models won't be the exact same byte-for-byte encoding, but they should give the same results either way :).

# nll for the old model
nll = rbig_model_info.score(X)
print(f"Negative Log-Likelihood: {nll:.4f}")

# nll for the loaded model
nll = rbig_model_loaded.score(X)
print(f"Negative Log-Likelihood: {nll:.4f}")
Negative Log-Likelihood: 0.5880
Negative Log-Likelihood: 0.5880